=== PAGE: https://www.union.ai/docs/v2/flyte ===
# Documentation
Welcome to the documentation.
## Subpages
- **Flyte**
- **Tutorials**
- **Integrations**
- **Reference**
- **Community**
- **Release notes**
- **Platform deployment**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide ===
# Flyte
Flyte is a free and open source platform that provides a full suite of powerful features for orchestrating AI workflows.
Flyte empowers AI development teams to rapidly ship high-quality code to production by offering optimized performance, unparalleled resource efficiency, and a delightful workflow authoring experience.
You deploy and manage Flyte yourself, on your own cloud infrastructure.
> [!NOTE]
> These are the Flyte **2.0** docs.
> To switch to [version 1.0](/docs/v1/flyte/) or to the commercial product, [**Union.ai**](/docs/v2/byoc/), use the selectors above.
>
> This documentation for open-source Flyte is maintained by Union.ai.
> **Note**
>
> Want to try Flyte without installing anything? [Try Flyte 2 in your browser](https://flyte2intro.apps.demo.hosted.unionai.cloud/).
### **From Flyte 1 to 2**
Flyte 2 represents a fundamental shift in how AI workflows are written and executed. Learn more in this section.
### **Quickstart**
Install Flyte 2, configure your local IDE, create and run your first task, and inspect the results in 2 minutes.
## Subpages
- **Overview**
- **Quickstart**
- **Core concepts**
- **Running locally**
- **Connecting to a cluster**
- **Projects and domains**
- **Basic project: RAG**
- **Advanced project: LLM reporting agent**
- **From Flyte 1 to 2**
- **Configure tasks**
- **Build tasks**
- **Run and deploy tasks**
- **Scale your runs**
- **Configure apps**
- **Build apps**
- **Serve and deploy apps**
- **Build an agent**
- **Sandboxing**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/overview ===
# Overview
In this guide we cover how to build AI applications, data pipelines, and ML workflows using the Flyte 2 SDK.
Programs written using the Flyte 2 SDK can run on either a Union.ai or Flyte OSS backend. This guide applies to both.
## Pure Python, no DSL
Flyte lets you write workflows in standard Python: no domain-specific language, no special syntax, no restrictions.
Your "workflow" is simply a task that calls other tasks:
```python
@env.task()
async def my_workflow(data: list[str]) -> list[str]:
    results = []
    for item in data:
        if should_process(item):
            result = await process_item(item)
            results.append(result)
    return results
```
You can use everything Python offers:
- **Loops and conditionals**: standard `for`, `while`, `if-elif-else`
- **Error handling**: `try/except` blocks work as expected
- **Async/await**: native Python concurrency model
- **Any library**: import and use whatever you need
This means no learning curve beyond Python itself, and no fighting a DSL when your requirements don't fit its constraints.
## Durability
Every task execution in Flyte is automatically persisted. Inputs, outputs, and intermediate results are stored in an object store, giving you:
- **Full observability**: see exactly what data flowed through each step
- **Audit trail**: track what ran, when, and with what parameters
- **Data lineage**: trace outputs back to their inputs
This persistence happens automatically. You don't need to add logging or manually save state; Flyte handles it.
## Reproducibility
Flyte ensures that runs can be reproduced exactly:
- **Deterministic execution**: same inputs produce same outputs
- **Caching**: task results are cached and reused when inputs match
- **Versioned containers**: code runs in the same environment every time
Caching is configurable per task:
```python
@env.task(cache="auto")
async def expensive_computation(data: str) -> str:
    # This result will be cached and reused for identical inputs
    ...
```
When you rerun a workflow, Flyte serves cached results for unchanged tasks rather than recomputing them.
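As a loose analogy in plain Python (no Flyte APIs involved), the caching behavior resembles memoization keyed on a function's inputs:

```python
import functools

@functools.lru_cache(maxsize=None)
def expensive_computation(data: str) -> str:
    # Stand-in for a cached task: the first call computes the result,
    # identical later calls return the stored value instead.
    return data.upper()

expensive_computation("hello")  # computed
expensive_computation("hello")  # served from the cache
print(expensive_computation.cache_info().hits)  # 1
```

Flyte's cache differs in that it is persisted in a backing store and keyed on the task name and serialized inputs, so it survives across processes and runs.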
## Recoverability
When something fails, Flyte doesn't make you start over. Failed workflows can resume from where they left off:
- **Completed tasks are preserved**: successful outputs remain cached
- **Retry from failure point**: no need to re-execute what already succeeded
- **Fine-grained checkpoints**: the `@flyte.trace` decorator creates checkpoints within tasks
This reduces wasted compute and speeds up debugging. When a task fails after hours of prior computation, you fix the issue and continue rather than restarting.
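Conceptually, resuming from the failure point means persisting completed work and skipping it on the next attempt. A plain-Python sketch (the file name and helper are hypothetical, not Flyte APIs):

```python
import json
import pathlib

CHECKPOINT = pathlib.Path("checkpoint.json")  # hypothetical local checkpoint file

def process_items(items: list[str]) -> list[str]:
    # Load previously completed results, then continue where we left off,
    # mimicking how a resumed run replays cached outputs up to the failure point.
    done = json.loads(CHECKPOINT.read_text()) if CHECKPOINT.exists() else []
    for item in items[len(done):]:
        done.append(item.upper())
        CHECKPOINT.write_text(json.dumps(done))  # persist after each step
    return done
```

In Flyte the persistence is automatic and stored remotely; this sketch only illustrates the skip-what-succeeded idea.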
## Built for scale
Flyte handles the hard parts of distributed execution:
- **Parallel execution**: express parallelism with `asyncio.gather()`; Flyte handles the rest
- **Dynamic workflows**: construct workflows based on runtime data, not just static definitions
- **Fast scheduling**: reusable containers achieve millisecond-level task startup
- **Resource management**: specify CPU, memory, and GPU requirements per task
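The parallel-execution point builds on standard `asyncio`. Stripped of Flyte entirely, the fan-out pattern is just:

```python
import asyncio

async def process(item: int) -> int:
    # Stand-in for a task; under Flyte, each call would run as its own action.
    await asyncio.sleep(0)
    return item * 2

async def fan_out(items: list[int]) -> list[int]:
    # Launch all calls concurrently and gather results in input order.
    return list(await asyncio.gather(*(process(i) for i in items)))

print(asyncio.run(fan_out([1, 2, 3])))  # [2, 4, 6]
```

When the same pattern runs on a Flyte backend, each awaited call becomes a separately scheduled, tracked execution.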
## What this means in practice
Consider a data pipeline that processes thousands of files, trains a model, and deploys it:
- If file processing fails on item 847, you fix the issue and resume from item 847
- If training succeeds, but deployment fails, you redeploy without retraining
- If you rerun next week with the same data, cached results skip redundant computation
- If you need to audit what happened, every step is recorded
Flyte gives you the flexibility of Python scripts with the reliability of a production system.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/quickstart ===
# Quickstart
Let's get you up and running with your first workflow on your local machine.
> **Note**
>
> Want to try Flyte without installing anything? [Try Flyte 2 in your browser](https://flyte2intro.apps.demo.hosted.unionai.cloud/).
## What you'll need
- Python 3.10+ in a virtual environment
## Install the SDK
Install the `flyte` package:
```bash
pip install 'flyte[tui]'
```
> **Note**
>
> We also install the `tui` extra to enable the terminal user interface.
Verify it worked:
```bash
flyte --version
```
Output:
```bash
Flyte SDK version: 2.*.*
```
## Configure
Create a config file for local execution. Runs will be persisted locally in a SQLite database.
```bash
flyte create config --local-persistence
```
This creates `.flyte/config.yaml` in your current directory. See [Setting up a configuration file](./connecting-to-a-cluster#configuration-file) for more options.
> **Note**
>
> Run `flyte get config` to check which configuration is currently active.
## Write your first workflow
Create `hello.py`:
```python
# hello.py
import flyte

# The `hello_env` TaskEnvironment is assigned to the variable `env`.
# It is then used in the `@env.task` decorator to define tasks.
# The environment groups configuration for all tasks defined within it.
env = flyte.TaskEnvironment(name="hello_env")

# We use the `@env.task` decorator to define a task called `fn`.
@env.task
def fn(x: int) -> int:  # Type annotations are required
    slope, intercept = 2, 5
    return slope * x + intercept

# We also use the `@env.task` decorator to define another task called `main`.
# This is the entrypoint task of the workflow.
# It calls the `fn` task defined above multiple times using `flyte.map`.
@env.task
def main(x_list: list[int] = list(range(10))) -> float:
    y_list = list(flyte.map(fn, x_list))  # flyte.map is like Python map, but runs in parallel.
    y_mean = sum(y_list) / len(y_list)
    return y_mean
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/hello.py*
Here's what's happening:
- **`TaskEnvironment`** specifies configuration for your tasks (container image, resources, etc.)
- **`@env.task`** turns Python functions into tasks that run remotely
- Both tasks share the same `env`, so they'll have identical configurations
## Run it
Create a project directory and place your files there:
```bash
mkdir my-project
cd my-project
# place hello.py in this directory
```
> [!WARNING]
> Do not run `flyte run` from your home directory. Flyte packages the current directory when running remotely, so running from `$HOME` would attempt to bundle your entire home folder. Always work from a dedicated project directory.
Run the workflow:
```bash
flyte run hello.py main
```
This executes the workflow locally on your machine.
## See the results
You can see the run in the TUI by running:
```bash
flyte start tui
```
The TUI will open into the explorer view.

To view the details of a run, double-click it or select it and press `Enter`.

## Next steps
Now that you've run your first workflow:
- [**Core concepts**](./core-concepts/_index): Understand the core concepts of Flyte programming
- [**Running locally**](./running-locally): Learn about the TUI, caching, and other features that work locally
- [**Connecting to a cluster**](./connecting-to-a-cluster): Configure your environment for remote execution
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/core-concepts ===
# Core concepts
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
Now that you've completed the **Quickstart**, let's explore Flyte's core concepts through working examples.
By the end of this section, you'll understand:
- **TaskEnvironment**: The container configuration that defines where and how your code runs
- **Tasks**: Python functions that execute remotely in containers
- **Runs and Actions**: How Flyte tracks and manages your executions
- **Apps**: Long-running services for APIs, dashboards, and inference endpoints
Each concept is introduced with a practical example you can run yourself.
## How Flyte works
When you run code with Flyte, here's what happens:
1. You define a **TaskEnvironment** that specifies the container image and resources
2. You decorate Python functions with `@env.task` to create **tasks**
3. When you execute a task, Flyte creates a **run** that tracks the execution
4. Each task execution within a run is an **action**
Let's explore each of these in detail.
## Subpages
- **Core concepts > TaskEnvironment**
- **Core concepts > Tasks**
- **Core concepts > Runs and actions**
- **Core concepts > Apps**
- **Core concepts > Key capabilities**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/core-concepts/task-environment ===
# TaskEnvironment
A `TaskEnvironment` defines the hardware and software environment where your tasks run. Think of it as the container configuration for your code.
## A minimal example
Here's the simplest possible TaskEnvironment:
```python
import flyte

env = flyte.TaskEnvironment(name="my_env")

@env.task
def hello() -> str:
    return "Hello from Flyte!"
```
With just a `name`, you get Flyte's default container image and resource allocation. This is enough for simple tasks that only need Python and the Flyte SDK.
## What TaskEnvironment controls
A TaskEnvironment specifies two things:
**Hardware environment** - The compute resources allocated to each task:
- CPU cores
- Memory
- GPU type and count
**Software environment** - The container image your code runs in:
- Base image (Python version, OS)
- Installed packages and dependencies
- Environment variables
## Configuring resources
Use the `limits` parameter to specify compute resources:
```python
env = flyte.TaskEnvironment(
    name="compute_heavy",
    limits=flyte.Resources(cpu="4", mem="16Gi"),
)
```
For GPU workloads:
```python
env = flyte.TaskEnvironment(
    name="gpu_training",
    limits=flyte.Resources(cpu="8", mem="32Gi", gpu="1"),
    accelerator=flyte.GPUAccelerator.NVIDIA_A10G,
)
```
## Configuring container images
For tasks that need additional Python packages, specify a custom image:
```python
image = flyte.Image.from_debian_base().with_pip_packages("pandas", "scikit-learn")

env = flyte.TaskEnvironment(
    name="ml_env",
    image=image,
)
```
See [Container images](../task-configuration/container-images) for detailed image configuration options.
## Multiple tasks, one environment
All tasks decorated with the same `@env.task` share that environment's configuration:
```python
env = flyte.TaskEnvironment(
    name="data_processing",
    limits=flyte.Resources(cpu="2", mem="8Gi"),
)

@env.task
def load_data(path: str) -> dict:
    # Runs with 2 CPU, 8Gi memory
    ...

@env.task
def transform_data(data: dict) -> dict:
    # Also runs with 2 CPU, 8Gi memory
    ...
```
This is useful when multiple tasks have similar requirements.
## Multiple environments
When tasks have different requirements, create separate environments:
```python
light_env = flyte.TaskEnvironment(
    name="light",
    limits=flyte.Resources(cpu="1", mem="2Gi"),
)

heavy_env = flyte.TaskEnvironment(
    name="heavy",
    limits=flyte.Resources(cpu="8", mem="32Gi"),
)

@light_env.task
def preprocess(data: str) -> str:
    # Light processing
    ...

@heavy_env.task
def train_model(data: str) -> dict:
    # Resource-intensive training
    ...
```
## Next steps
Now that you understand TaskEnvironments, let's look at how to define [tasks](./tasks) that run inside them.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/core-concepts/tasks ===
# Tasks
A task is a Python function that runs remotely in a container. You create tasks by decorating functions with `@env.task`.
## Defining a task
Here's a simple task:
```python
import flyte

env = flyte.TaskEnvironment(name="my_env")

@env.task
def greet(name: str) -> str:
    return f"Hello, {name}!"
```
The `@env.task` decorator tells Flyte to run this function in a container configured by `env`.
## Type hints are required
Flyte uses type hints to understand your data and serialize it between tasks:
```python
@env.task
def process_numbers(values: list[int]) -> int:
    return sum(values)
```
Supported types include:
- Primitives: `int`, `float`, `str`, `bool`
- Collections: `list`, `dict`, `tuple`
- DataFrames: `pandas.DataFrame`, `polars.DataFrame`
- Files: `flyte.File`, `flyte.Directory`
- Custom: dataclasses, Pydantic models
See [Data classes and structures](../task-programming/dataclasses-and-structures) for complex types.
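For instance, a dataclass can serve as a typed payload between tasks. The sketch below uses plain functions so it runs anywhere; with Flyte you would decorate them with `@env.task`, and the full type hints are what let Flyte serialize the value between containers:

```python
from dataclasses import dataclass

@dataclass
class Metrics:
    accuracy: float
    samples: int

def evaluate(samples: int) -> Metrics:
    # Returning a fully type-hinted dataclass lets Flyte serialize it
    # automatically when this runs as a task.
    return Metrics(accuracy=0.95, samples=samples)

def summarize(m: Metrics) -> str:
    return f"{m.accuracy:.2f} over {m.samples} samples"

print(summarize(evaluate(100)))  # 0.95 over 100 samples
```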
## Tasks calling tasks
In Flyte 2, tasks can call other tasks directly. The called task runs in its own container:
```python
@env.task
def fetch_data(url: str) -> dict:
    # Runs in container 1
    ...

@env.task
def process_data(url: str) -> str:
    data = fetch_data(url)  # Calls fetch_data, runs in container 2
    return transform(data)
```
This is how you build workflows in Flyte 2. There's no separate `@workflow` decorator - just tasks calling tasks.
## The top-level task
The task you execute directly is the "top-level" or "driver" task. It orchestrates other tasks:
```python
@env.task
def step_one(x: int) -> int:
    return x * 2

@env.task
def step_two(x: int) -> int:
    return x + 10

@env.task
def pipeline(x: int) -> int:
    a = step_one(x)  # Run step_one
    b = step_two(a)  # Run step_two with result
    return b
```
When you run `pipeline`, it becomes the top-level task and orchestrates `step_one` and `step_two`.
## Running tasks locally
For quick testing, you can call a task like a regular function:
```python
# Direct call - runs locally, not in a container
result = greet("World")
print(result) # "Hello, World!"
```
This bypasses Flyte entirely and is useful for debugging logic. However, local calls don't track data, use remote resources, or benefit from Flyte's features.
## Running tasks remotely
To run a task on your Flyte backend:
```python
import flyte
flyte.init_from_config()
result = flyte.run(greet, name="World")
print(result) # "Hello, World!"
```
Or from the command line:
```bash
flyte run my_script.py greet --name World
```
This sends your code to the Flyte backend, runs it in a container, and returns the result.
## Next steps
Now that you can define and run tasks, let's understand how Flyte tracks executions with [runs and actions](./runs-and-actions).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/core-concepts/runs-and-actions ===
# Runs and actions
When you execute a task on Flyte, the system creates a **run** to track it. Each individual task execution within that run is an **action**. Understanding this hierarchy helps you navigate the UI and debug your workflows.
## What is a run?
A **run** is the execution of a task that you directly initiate, plus all its descendant task executions, considered as a single unit.
When you execute:
```bash
flyte run my_script.py pipeline --x 5
```
Flyte creates a run for `pipeline`. If `pipeline` calls other tasks, those executions are part of the same run.
## What is an action?
An **action** is the execution of a single task, considered independently. A run consists of one or more actions.
Consider this workflow:
```python
@env.task
def step_one(x: int) -> int:
    return x * 2

@env.task
def step_two(x: int) -> int:
    return x + 10

@env.task
def pipeline(x: int) -> int:
    a = step_one(x)
    b = step_two(a)
    return b
```
When you run `pipeline(5)`:
- **1 run** is created for the entire execution
- **3 actions** are created: one for `pipeline`, one for `step_one`, one for `step_two`
## Runs vs actions in practice
| Concept | What it represents | In the UI |
|---------|-------------------|-----------|
| **Run** | Complete execution initiated by user | Runs list, top-level view |
| **Action** | Single task execution | Individual task details, logs |
For details on how to run tasks locally and remotely, see [Tasks](./tasks#running-tasks-locally).
## Viewing runs in the UI
After running a task remotely, click the URL in the output to see your run in the UI:
```bash
flyte run my_script.py pipeline --x 5
```
Output:
```bash
abc123xyz
https://my-instance.example.com/v2/runs/project/my-project/domain/development/abc123xyz
Run 'a0' completed successfully.
```
In the UI, you can:
- See the overall run status and duration
- Navigate to individual actions
- View inputs and outputs for each task
- Access logs for debugging
- See the execution graph
## Understanding the execution graph
The UI shows how tasks relate to each other:
```
pipeline (action)
├── step_one (action)
└── step_two (action)
```
Each box is an action. Arrows show data flow between tasks. This visualization helps you understand complex workflows and identify bottlenecks.
## Checking run status
From the command line:
```bash
flyte get run
```
From Python:
```python
import flyte

flyte.init_from_config()
run = flyte.run(pipeline, x=5)
# The run object has status information
print(run.status)
```
## Next steps
You now understand tasks and how Flyte tracks their execution. Next, let's learn about [apps](./introducing-apps) - Flyte's approach to long-running services.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/core-concepts/introducing-apps ===
# Apps
Now that you understand tasks, let's learn about apps - Flyte's way of running long-lived services.
## Tasks vs apps
You've already learned about **tasks**: Python functions that run to completion in containers. Tasks are great for data processing, training, and batch operations.
**Apps** are different. An app is a long-running service that stays active and handles requests over time. Apps are ideal for:
- REST APIs and webhooks
- Model inference endpoints
- Interactive dashboards
- Real-time data services
| Aspect | Task | App |
|--------|------|-----|
| Lifecycle | Runs once, then exits | Stays running indefinitely |
| Invocation | Called with inputs, returns outputs | Receives HTTP requests |
| Use case | Batch processing, training | APIs, inference, dashboards |
| Durability | Inputs/outputs stored, can resume | Stateless request handling |
## AppEnvironment
Just as tasks use `TaskEnvironment`, apps use `AppEnvironment` to configure their runtime.
An `AppEnvironment` specifies:
- **Hardware**: CPU, memory, GPU allocation
- **Software**: Container image with dependencies
- **App-specific settings**: Ports, scaling, authentication
Here's a simple example:
```python
import flyte
from flyte.app.extras import FastAPIAppEnvironment

env = FastAPIAppEnvironment(
    name="my-app",
    image=flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn"),
    limits=flyte.Resources(cpu="1", mem="2Gi"),
)
```
## A hello world app
Let's create a minimal FastAPI app to see how this works.
First, create `hello_app.py`:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "fastapi",
#     "uvicorn",
# ]
# ///
"""A simple "Hello World" FastAPI app example for serving."""

import pathlib

from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# Define a simple FastAPI application
app = FastAPI(
    title="Hello World API",
    description="A simple FastAPI application",
    version="1.0.0",
)

# Create an AppEnvironment for the FastAPI app
env = FastAPIAppEnvironment(
    name="hello-app",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)

# Define API endpoints
@app.get("/")
async def root():
    return {"message": "Hello, World!"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

# Serving this script will deploy and serve the app on your Union/Flyte instance.
if __name__ == "__main__":
    # Initialize Flyte from a config file.
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    # Serve the app remotely.
    app_instance = flyte.serve(env)
    # Print the app URL.
    print(app_instance.url)
    print("App 'hello-app' is now serving.")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/serving/hello_app.py*
### Understanding the code
- **`FastAPI()`** creates the web application with its endpoints
- **`FastAPIAppEnvironment`** configures the container and resources
- **`@app.get("/")`** defines an HTTP endpoint that returns a greeting
- **`flyte.serve()`** deploys and starts the app on your Flyte backend
### Serving the app
With your config file in place, serve the app:
```bash
flyte serve hello_app.py env
```
Or run the Python file directly (which calls `flyte.serve()` in the main block):
```bash
python hello_app.py
```
You'll see output like:
```output
https://my-instance.flyte.com/v2/domain/development/project/my-project/apps/hello-app
App 'hello-app' is now serving.
```
Click the link to view your app in the UI. You can find the app URL there, or visit `/docs` for FastAPI's interactive API documentation.
## When to use apps vs tasks
Use **tasks** when:
- Processing takes seconds to hours
- You need durability (inputs/outputs tracked)
- Work is triggered by events or schedules
- Results need to be cached or resumed
Use **apps** when:
- Responses must be fast (milliseconds)
- You're serving an API or dashboard
- Users interact in real-time
- You need a persistent endpoint
## Common patterns
**Model serving with FastAPI**: Train a model with a Flyte pipeline, then serve predictions from it. During local development, the app loads the model from a local file. When deployed remotely, Flyte's `Parameter` system automatically resolves the model from the latest training run output. See [FastAPI app](../build-apps/fastapi-app) for the full example.
**Agent UI with Gradio**: Build an interactive UI that kicks off agent runs using `flyte.with_runcontext()`. A single `RUN_MODE` environment variable controls the deployment progression: fully local (rapid iteration), local UI with remote task execution (cluster compute), or fully remote (production). See [Build apps](../build-apps/_index) for details.
## Next steps
You now understand the core building blocks of Flyte:
- **TaskEnvironment** and **AppEnvironment** configure where code runs
- **Tasks** are functions that execute and complete
- **Apps** are long-running services
- **Runs** and **Actions** track executions
Before diving deeper, check out [Key capabilities](./key-capabilities) for an overview of what Flyte can do, from parallelism and caching to LLM serving and error recovery.
Then head to [Basic project](../basic-project) to build an end-to-end ML system with training tasks and a serving app.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/core-concepts/key-capabilities ===
# Key capabilities
Now that you understand the core concepts -- `TaskEnvironment`, tasks, runs, and apps -- here's an overview of what Flyte can do. Each capability is covered in detail later in the documentation.
## Environment and resources
Configure how and where your code runs.
- **Multiple environments**: Create separate configurations for different use cases (dev, prod, GPU vs CPU)
  → [Multiple environments](../task-configuration/multiple-environments)
- **Resource specification**: Request specific CPU, memory, GPU, and storage for your tasks
  → [Resources](../task-configuration/resources)
## Deployment
Get your code running remotely.
- **Code packaging**: Your local code is automatically bundled and deployed to remote execution
  → [Packaging](../task-deployment/packaging)
- **Local testing**: Test tasks locally before deploying with `flyte run --local`
  → [How task run works](../task-deployment/how-task-run-works)
## Data handling
Pass data efficiently between tasks.
- **Files and directories**: Pass large files and directories between tasks using `flyte.io.File` and `flyte.io.Dir`
  → [Files and directories](../task-programming/files-and-directories)
- **DataFrames**: Work with pandas, Polars, and other DataFrame types natively
  → [DataFrames](../task-programming/dataframes)
## Parallelism and composition
Scale out and compose workflows.
- **Fanout parallelism**: Process items in parallel using `flyte.map` or `asyncio.gather`
  → [Fanout](../task-programming/fanout)
- **Remote tasks**: Call previously deployed tasks from within your workflows
  → [Remote tasks](../task-programming/remote-tasks)
## Security and automation
Manage credentials and automate execution.
- **Secrets**: Inject API keys, passwords, and other credentials securely into tasks
  → [Secrets](../task-configuration/secrets)
- **Triggers**: Schedule tasks on a cron schedule or trigger them from external events
  → [Triggers](../task-configuration/triggers)
- **Webhooks**: Build APIs that trigger task execution from external systems
  → [App usage patterns](../build-apps/app-usage-patterns)
## Durability and reliability
Handle failures and avoid redundant work.
- **Error handling**: Catch failures and retry with different resources (e.g., more memory)
  → [Error handling](../task-programming/error-handling)
- **Retries and timeouts**: Configure automatic retries and execution time limits
  → [Retries and timeouts](../task-configuration/retries-and-timeouts)
- **Caching**: Add `cache="auto"` to any task and Flyte stores its outputs keyed on task name and inputs. Same inputs means instant results with no recomputation. This speeds up your development loop: skip re-downloading data, avoid replaying earlier steps in agentic chains, or bypass any expensive computation while you iterate.
  → [Caching](../task-configuration/caching)
```python
@env.task(cache="auto")
async def load_data(data_dir: str = "./data") -> str:
    """Downloads once, then returns instantly on subsequent runs."""
    # ... expensive download ...
    return data_dir
```
- **Traces**: Use `@flyte.trace` to get visibility into the internal steps of a task without the overhead of making each step a separate task. Traced functions show up as child nodes under their parent task, each with their own timing, inputs, and outputs. This is particularly useful for AI agents where you want to see which tools were called.
  → [Traces](../task-programming/traces)
```python
@flyte.trace
async def search(query: str) -> str:
    """Shows up as a child node under the parent task."""
    return await do_search(query)

@env.task
async def agent(request: str) -> str:
    results = await search(request)  # Traced
    answer = await summarize(results)  # Also traced if decorated
    return answer
```
- **Reports**: Add `report=True` to a task and it can generate an HTML report (charts, tables, images) saved alongside the task output. Combined with caching and persisted inputs/outputs, reports act as lightweight experiment tracking: each run produces a self-contained HTML file you can compare across runs and share with your team.
  → [Reports](../task-programming/reports)
```python
import flyte.report

@env.task(report=True)
async def evaluate(model_file: File, test_data: str) -> str:
    # ... run evaluation ...
    await flyte.report.replace.aio(
        f"<h1>Training Report</h1>"
        f"<h2>Test Results</h2>"
        f"<p>Accuracy: {accuracy:.4f}</p>"
    )
    await flyte.report.flush.aio()
    return f"Accuracy: {accuracy:.4f}"
```
## Apps and serving
Deploy long-running services.
- **FastAPI apps**: Deploy REST APIs and webhooks
  → [FastAPI app](../build-apps/fastapi-app)
- **LLM serving**: Serve large language models with vLLM or SGLang
  → [vLLM app](../build-apps/vllm-app), [SGLang app](../build-apps/sglang-app)
- **Autoscaling**: Scale apps up and down based on traffic, including scale-to-zero
  → [Autoscaling apps](../configure-apps/auto-scaling-apps)
- **Streamlit dashboards**: Deploy interactive data dashboards
  → [Streamlit app](../build-apps/streamlit-app)
## Notebooks
Work interactively.
- **Jupyter support**: Author and run workflows directly from Jupyter notebooks, and fetch workflow metadata (inputs, outputs, logs)
  → [Notebooks](../task-programming/notebooks)
## Next steps
Ready to put it all together? Head to [Basic project](../basic-project) to build an end-to-end ML system with training tasks and a serving app.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/running-locally ===
# Running locally
Flyte runs locally with no cluster or Docker needed. Install the SDK, write tasks, and run them on your machine. When you're ready to scale, drop the `--local` flag and the same code runs on a remote cluster with GPUs.

---
## Getting started
If you haven't already, install the SDK and configure local persistence as described in the [Quickstart](./quickstart).
---
## Running tasks locally
The `--local` flag tells Flyte to execute a task in your local Python environment rather than on a remote cluster. Add `--tui` to launch the interactive Terminal UI for real-time monitoring.
Basic local execution:
```bash
flyte run --local my_pipeline.py my_task --arg value
```
With the interactive TUI:
```bash
flyte run --local --tui my_pipeline.py my_task --arg value
```
You can also run tasks programmatically using the Python SDK with `flyte.run()`. See [Run and deploy tasks](./task-deployment/_index) for details.
Drop `--local` to run on a remote cluster when one is configured:
```bash
flyte run my_pipeline.py my_task --arg value
```
---
## Terminal UI
The TUI is an interactive split-screen dashboard. Task tree on the left, details and logs on the right.
```bash
flyte run --local --tui my_pipeline.py pipeline --epochs 5
```

What you see:
- **Task tree** with live status indicators for running, completed, and failed tasks
- **Cache indicators**: `$` cache hit, `~` cache enabled but missed
- **Live logs**: `print()` output streams in real time
- **Details panel**: inputs, outputs, timing, report paths
- **Traced sub-tasks**: child nodes for `@flyte.trace` decorated functions
**Keyboard shortcuts:**
| Key | Action |
|-----|--------|
| `q` | Quit |
| `d` | Details tab |
| `l` | Logs tab |
### Exploring past runs
Flyte persists the inputs and outputs of every task run locally, so you can always go back and inspect what a task received and produced. Launch the TUI on its own to browse past runs, compare inputs and outputs, and review reports:
```bash
flyte start tui
```
---
## What works locally
Most Flyte features work in both local and remote execution. The table below summarizes how each feature behaves locally.
| Feature | Local behavior | Details |
|---------|---------------|---------|
| **Caching** | Outputs stored in local SQLite, keyed on task name and inputs. Same inputs = instant results. | [Caching](./task-configuration/caching) |
| **Tracing** | `@flyte.trace` functions appear as child nodes in the TUI with their own timing, inputs, and outputs. | [Traces](./task-programming/traces) |
| **Reports** | HTML files saved locally. TUI shows the file path. | [Reports](./task-programming/reports) |
| **Serving** | Run apps locally with `python serve.py` or `flyte.with_servecontext(mode="local")`. | [Serve and deploy apps](./serve-and-deploy-apps/_index) |
| **Plugins** | Same decorators and APIs as remote. Secrets come from environment variables. | [Integrations](../integrations/_index) |
| **Secrets** | Read from `.env` files or environment variables. No `flyte create secret` needed. | [Secrets](./task-configuration/secrets) |
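To build intuition for the caching row above: locally, a cached result is looked up by task name plus a digest of the inputs, so identical calls skip execution. A minimal sketch of that idea in plain Python (a toy model, not Flyte's actual implementation):

```python
import hashlib
import json
import sqlite3

# Toy model of a local result cache keyed on task name + inputs
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE cache (key TEXT PRIMARY KEY, output TEXT)")

def cache_key(task_name: str, inputs: dict) -> str:
    # Stable digest: same task + same inputs -> same key
    blob = json.dumps(inputs, sort_keys=True)
    return hashlib.sha256(f"{task_name}:{blob}".encode()).hexdigest()

def run_cached(task_name: str, inputs: dict, fn) -> str:
    key = cache_key(task_name, inputs)
    row = conn.execute("SELECT output FROM cache WHERE key = ?", (key,)).fetchone()
    if row:
        return row[0]  # cache hit: skip execution entirely
    output = fn(**inputs)
    conn.execute("INSERT INTO cache VALUES (?, ?)", (key, output))
    return output

calls = []

def train(epochs: int) -> str:
    calls.append(epochs)
    return f"model trained for {epochs} epochs"

first = run_cached("train", {"epochs": 5}, train)
second = run_cached("train", {"epochs": 5}, train)  # instant: train() is not called again
```

The second call returns immediately from the cache; only changing the inputs (or the task name) produces a new cache key and a fresh execution.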
---
## Local to remote
The same code runs in both environments. Here's what changes:
| Aspect | Local | Remote |
|--------|-------|--------|
| **Run pipeline** | `flyte run --local` | `flyte run` |
| **TUI** | `--tui` flag | Dashboard in Flyte UI |
| **Caching** | Local SQLite | Cluster-wide distributed cache |
| **Reports** | Local HTML files | Rendered in the Flyte UI |
| **Serving** | `python serve.py` | `flyte deploy serve.py env` |
| **Secrets** | `.env` / environment variables | `flyte create secret` / `flyte.Secret` |
| **Compute** | Your CPU/GPU | `Resources(cpu=2, memory="4Gi", gpu=1)` |
The [`TaskEnvironment`](./core-concepts/task-environment) is the bridge. Locally, image and resource settings are ignored. On the cluster, Flyte builds containers and allocates compute from the same definition.
---
## Next steps
When you're ready to run on a remote Flyte cluster, see [Connecting to a cluster](./connecting-to-a-cluster) to configure the CLI and SDK.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/connecting-to-a-cluster ===
# Connecting to a cluster
This guide covers setting up your local development environment and configuring the `flyte` CLI and SDK to connect to your Union/Flyte instance.
> [!NOTE]
> Want to try Flyte without installing anything? [Try Flyte 2 in your browser](https://flyte2intro.apps.demo.hosted.unionai.cloud/).
## Prerequisites
- **Python 3.10+**
- **`uv`** β A fast Python package installer. See the [`uv` installation guide](https://docs.astral.sh/uv/getting-started/installation/).
- Access to a Union/Flyte instance (URL and a project where you can run workflows)
## Install the flyte package
Create a virtual environment and install the `flyte` package:
```bash
uv venv
source .venv/bin/activate
uv pip install flyte
```
> [!NOTE]
> On Windows, use `.venv\Scripts\activate` instead.
Verify installation:
```bash
flyte --version
```
## Configuration file
As we did in [Quickstart](./quickstart), use `flyte create config` to create a configuration file:
```bash
flyte create config \
    --endpoint my-org.my-company.com \
    --domain development \
    --project my-project \
    --builder local
```
This creates `./.flyte/config.yaml`:
```yaml
admin:
  endpoint: dns:///my-org.my-company.com
image:
  builder: local
task:
  org: my-org
  domain: development
  project: my-project
```
### Full example with all options
Create a custom config file with all available options:
```bash
flyte create config \
    --endpoint my-org.my-company.com \
    --org my-org \
    --domain development \
    --project my-project \
    --builder local \
    --insecure \
    --output my-config.yaml \
    --force
```
### Set up local Docker
Since Flyte OSS uses local image building, you'll need Docker running and logged into the GitHub registry:
```bash
docker login ghcr.io
```
> [!NOTE]
> The `--builder local` option means images are [built locally](./task-configuration/container-images). Union instances can use `--builder remote` instead.
See the [CLI reference](../api-reference/flyte-cli#flyte-create-config) for all parameters.
### Config properties explained
**`admin`** β Connection details for your Union/Flyte instance.
- `endpoint`: URL with `dns:///` prefix. If your UI is at `https://my-org.my-company.com`, use `dns:///my-org.my-company.com`.
- `insecure`: Set to `true` only for local instances without TLS.
**`image`** β Docker image building configuration.
- `builder`: How container images are built.
- `remote` (Union): Images built on Union's infrastructure.
- `local` (Flyte OSS): Images built on your machine. Requires Docker. See [Image building](./task-configuration/container-images#image-building).
**`task`** β Default settings for task execution.
- `org`: Organization name (usually matches the first part of your endpoint URL).
- `domain`: Environment separation (`development`, `staging`, `production`).
- `project`: Default project for deployments. Must already exist on your instance. See [Projects and domains](./projects-and-domains) for how to create projects.
## Using the configuration
You can reference your config file explicitly or let the SDK find it automatically.
### Explicit configuration
#### Programmatic
Initialize with [`flyte.init_from_config`](../api-reference/flyte-sdk/packages/flyte/_index#init_from_config):
```python
import flyte

flyte.init_from_config("my-config.yaml")
run = flyte.run(main)
```
#### CLI
Use `--config` or `-c`:
```bash
flyte --config my-config.yaml run hello.py main
flyte -c my-config.yaml run hello.py main
```
### Configuration precedence
Without an explicit path, the SDK searches these locations in order:
1. `./config.yaml`
2. `./.flyte/config.yaml`
3. `UCTL_CONFIG` environment variable
4. `FLYTECTL_CONFIG` environment variable
5. `~/.union/config.yaml`
6. `~/.flyte/config.yaml`
#### Programmatic
```python
flyte.init_from_config()
```
#### CLI
```bash
flyte run hello.py main
```
### Check current configuration
```bash
flyte get config
```
Output:
```bash
CLIConfig(
    Config(
        platform=PlatformConfig(endpoint='dns:///my-org.my-company.com', scopes=[]),
        task=TaskConfig(org='my-org', project='my-project', domain='development'),
        source=PosixPath('/Users/me/.flyte/config.yaml')
    ),
    ...
)
```
## Inline configuration
Skip the config file entirely by passing parameters directly.
### Programmatic
Use [`flyte.init`](../api-reference/flyte-sdk/packages/flyte/_index#init):
```python
import flyte

flyte.init(
    endpoint="dns:///my-org.my-company.com",
    org="my-org",
    project="my-project",
    domain="development",
)
```
### CLI
Some parameters go after `flyte`, others after the subcommand:
```bash
flyte \
    --endpoint my-org.my-company.com \
    --org my-org \
    run \
    --domain development \
    --project my-project \
    hello.py \
    main
```
See the [CLI reference](../api-reference/flyte-cli) for details.
See related methods:
* [`flyte.init_from_api_key`](../api-reference/flyte-sdk/packages/flyte/_index#init_from_api_key)
* [`flyte.init_from_config`](../api-reference/flyte-sdk/packages/flyte/_index#init_from_config)
* [`flyte.init_in_cluster`](../api-reference/flyte-sdk/packages/flyte/_index#init_in_cluster)
* [`flyte.init_passthrough`](../api-reference/flyte-sdk/packages/flyte/_index#init_passthrough)
## Next steps
With your environment fully configured, you're ready to build:
- [**Core concepts**](./core-concepts/_index): Understand `TaskEnvironment`s, tasks, runs, and actions through working examples.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/projects-and-domains ===
# Projects and domains
Flyte organizes work into a hierarchy of **organization**, **projects**, and **domains**.
- **Organization**: Your Flyte instance, typically representing a company or department. Set up during onboarding and mapped to your endpoint URL (e.g., `my-org.my-company.com`). You do not create or manage organizations directly. The organization is normally determined automatically from your endpoint URL, but you can override it with the `--org` flag on any CLI command (e.g., `flyte --org my-org get project`). This is only relevant if you have a multi-organization installation.
- **Project**: A logical grouping of related workflows, tasks, launch plans, and executions. Projects are the primary unit you create and manage.
- **Domain**: An environment classification within each project. Three fixed domains exist: `development`, `staging`, and `production`. Domains cannot be created or deleted.
Every project contains all three domains, creating **project-domain pairs** like `my-project/development`, `my-project/staging`, and `my-project/production`. Workflows, executions, and data are scoped to a specific project-domain pair.
## How projects and domains are used
When you run or deploy workflows, you target a project and domain:
- **CLI**: Use `--project` and `--domain` flags with `flyte run` or `flyte deploy`, or set defaults in your [configuration file](./connecting-to-a-cluster).
- **Python SDK**: Specify `project` and `domain` in [`flyte.init`](../api-reference/flyte-sdk/packages/flyte/_index#init) or [`flyte.init_from_config`](../api-reference/flyte-sdk/packages/flyte/_index#init_from_config).
Projects and domains also determine data isolation. Storage and cache are isolated per project-domain pair.
## Managing projects via CLI
### Create a project
```shell
flyte create project --id my-project --name "My Project"
```
The `--id` is a unique identifier used in CLI commands and configuration (immutable once set). The `--name` is a human-readable display name.
You can also add a description and labels:
```shell
flyte create project \
    --id my-project \
    --name "My Project" \
    --description "ML platform workflows" \
    -l team=ml-platform \
    -l env=prod
```
Labels are specified as `-l key=value` and can be repeated.
### List projects
List all active projects:
```shell
flyte get project
```
Get details of a specific project:
```shell
flyte get project my-project
```
List archived projects:
```shell
flyte get project --archived
```
### Update a project
Update the name, description, or labels of a project:
```shell
flyte update project my-project --description "Updated description"
flyte update project my-project --name "New Display Name"
flyte update project my-project -l team=ml -l env=staging
```
> [!NOTE]
> Setting labels replaces all existing labels on the project.
### Archive a project
Archiving a project hides it from default listings but does not delete its data:
```shell
flyte update project my-project --archive
```
### Unarchive a project
Restore an archived project to active status:
```shell
flyte update project my-project --unarchive
```
## Listing projects programmatically
You can list and retrieve projects from Python using [`flyte.remote.Project`](../api-reference/flyte-sdk/packages/flyte.remote/project/_index):
```python
import flyte

flyte.init_from_config()

# Get a specific project
project = flyte.remote.Project.get(name="my-project", org="my-org")

# List all projects
for project in flyte.remote.Project.listall():
    print(project.to_dict())

# List with filtering and sorting
for project in flyte.remote.Project.listall(sort_by=("created_at", "desc")):
    print(project.to_dict())
```
Both `get()` and `listall()` support async execution via `.aio()`:
```python
project = await flyte.remote.Project.get.aio(name="my-project", org="my-org")
```
> [!NOTE]
> The Python SDK provides read-only access to projects. To create or modify projects, use the `flyte` CLI or the UI.
## Managing projects via the UI
When you log in to your Flyte instance, you land on the **Projects** page, which lists all projects in your organization. By default, the domain is set to `development`. You can change the active domain using the selector in the top left.
A **Recently viewed** list on the left sidebar provides quick access to your most commonly used projects.
From the project list you can:
* **Open a project**: Select a project from the list to navigate to it.
* **Create a project**: Click **+ New project** in the top right. In the dialog, specify a name and description. The project will be created across all three domains.
* **Archive a project**: Click the three-dot menu on a project's entry and select **Archive project**.
## Domains
Domains provide environment separation within each project. The three domains are:
| Domain | Purpose |
|--------|---------|
| `development` | For iterating on workflows during active development. |
| `staging` | For testing workflows before promoting to production. |
| `production` | For production workloads. |
Domains are predefined and cannot be created, renamed, or deleted.
### Targeting a domain
Set the default domain in your configuration file:
```yaml
task:
  domain: development
```
Or override per command:
```shell
flyte run --domain staging hello.py main
```
When using `flyte deploy`, the domain determines where the deployed workflows will execute:
```shell
flyte deploy --project my-project --domain production workflows
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/basic-project ===
# Basic project: RAG
This example demonstrates a two-stage RAG (Retrieval-Augmented Generation) pattern:
an offline embedding pipeline that processes and stores quotes, followed by an online
serving application that enables semantic search.
## Concepts covered
- `TaskEnvironment` for defining task execution environments
- `Dir` artifacts for passing directories between tasks
- `AppEnvironment` for serving applications
- `Parameter` and `RunOutput` for connecting apps to task outputs
- Semantic search with sentence-transformers and ChromaDB
## Part 1: The embedding pipeline
The embedding pipeline fetches quotes from a public API, creates vector embeddings
using sentence-transformers, and stores them in a ChromaDB database.
### Setting up the environment
The `TaskEnvironment` defines the execution environment for all tasks in the pipeline.
It specifies the container image, required packages, and resource allocations:
```python
# Define the embedding environment
embedding_env = flyte.TaskEnvironment(
    name="quote-embedding",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "sentence-transformers>=2.2.0",
        "chromadb>=0.4.0",
        "requests>=2.31.0",
    ),
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    cache="auto",
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/embed.py*
The environment uses:
- `Image.from_debian_base()` to create a container with Python 3.12
- `with_pip_packages()` to install sentence-transformers and ChromaDB
- `Resources` to request 2 CPUs and 4GB of memory
- `cache="auto"` to enable automatic caching of task outputs
### Fetching data
The `fetch_quotes` task retrieves quotes from a public API:
```python
@embedding_env.task
async def fetch_quotes() -> list[dict]:
    """
    Fetch quotes from a public quotes API.

    Returns:
        List of quote dictionaries with 'quote' and 'author' fields.
    """
    import requests

    print("Fetching quotes from API...")
    response = requests.get("https://dummyjson.com/quotes?limit=100")
    response.raise_for_status()
    data = response.json()
    quotes = data.get("quotes", [])
    print(f"Fetched {len(quotes)} quotes")
    return quotes
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/embed.py*
This task demonstrates:
- Async task definition with `async def`
- Returning structured data (`list[dict]`) from a task
- Using the `@embedding_env.task` decorator to associate the task with its environment
### Creating embeddings
The `embed_quotes` task creates vector embeddings and stores them in ChromaDB:
```python
@embedding_env.task
async def embed_quotes(quotes: list[dict]) -> Dir:
    """
    Create embeddings for quotes and store them in ChromaDB.

    Args:
        quotes: List of quote dictionaries with 'quote' and 'author' fields.

    Returns:
        Directory containing the ChromaDB database.
    """
    import chromadb
    from sentence_transformers import SentenceTransformer

    print("Loading embedding model...")
    model = SentenceTransformer("all-MiniLM-L6-v2")

    # Create ChromaDB in a temporary directory
    db_dir = tempfile.mkdtemp()
    print(f"Creating ChromaDB at {db_dir}...")
    client = chromadb.PersistentClient(path=db_dir)
    collection = client.create_collection(
        name="quotes",
        metadata={"hnsw:space": "cosine"},
    )

    # Prepare data for insertion
    texts = [q["quote"] for q in quotes]
    ids = [str(q["id"]) for q in quotes]
    metadatas = [{"author": q["author"], "quote": q["quote"]} for q in quotes]

    print(f"Embedding {len(texts)} quotes...")
    embeddings = model.encode(texts, show_progress_bar=True)

    # Add to collection
    collection.add(
        ids=ids,
        embeddings=embeddings.tolist(),
        metadatas=metadatas,
        documents=texts,
    )
    print(f"Stored {len(quotes)} quotes in ChromaDB")
    return await Dir.from_local(db_dir)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/embed.py*
Key points:
- Uses the `all-MiniLM-L6-v2` model from sentence-transformers (runs on CPU)
- Creates a persistent ChromaDB database with cosine similarity
- Returns a `Dir` artifact that captures the entire database directory
- The `await Dir.from_local()` call uploads the directory to artifact storage
### Orchestrating the pipeline
The main pipeline task composes the individual tasks:
```python
@embedding_env.task
async def embedding_pipeline() -> Dir:
    """
    Main pipeline that fetches quotes and creates embeddings.

    Returns:
        Directory containing the ChromaDB database with quote embeddings.
    """
    print("Starting embedding pipeline...")

    # Fetch quotes from API
    quotes = await fetch_quotes()

    # Create embeddings and store in ChromaDB
    db_dir = await embed_quotes(quotes)

    print("Embedding pipeline complete!")
    return db_dir
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/embed.py*
### Running the pipeline
To run the embedding pipeline:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(embedding_pipeline)
    print(f"Embedding run URL: {run.url}")
    run.wait()
    print(f"Embedding complete! Database directory: {run.outputs()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/embed.py*
```bash
uv run embed.py
```
The pipeline will:
1. Fetch 100 quotes from the API
2. Create embeddings using sentence-transformers
3. Store everything in a ChromaDB database
4. Return the database as a `Dir` artifact
## Part 2: The serving application
The serving application provides a Streamlit web interface for searching quotes
using the embeddings created by the pipeline.
### App environment configuration
The `AppEnvironment` defines how the application runs:
```python
# Define the app environment
env = AppEnvironment(
    name="quote-search-app",
    description="Semantic search over quotes using embeddings",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "streamlit>=1.41.0",
        "sentence-transformers>=2.2.0",
        "chromadb>=0.4.0",
    ),
    args=["streamlit", "run", "app.py", "--server.port", "8080"],
    port=8080,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    parameters=[
        Parameter(
            name="quotes_db",
            value=RunOutput(task_name="quote-embedding.embedding_pipeline", type="directory"),
            download=True,
            env_var="QUOTES_DB_PATH",
        ),
    ],
    include=["app.py"],
    requires_auth=False,
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/serve.py*
Key configuration:
- `args` specifies the command to run the Streamlit app
- `port=8080` exposes the application on port 8080
- `parameters` defines inputs to the app:
- `RunOutput` connects to the embedding pipeline's output
- `download=True` downloads the directory to local storage
- `env_var="QUOTES_DB_PATH"` makes the path available to the app
- `include=["app.py"]` bundles the Streamlit app with the deployment
### The Streamlit application
The app loads the ChromaDB database using the path from the environment variable:
```python
# Load the database
@st.cache_resource
def load_db():
    db_path = os.environ.get("QUOTES_DB_PATH")
    if not db_path:
        st.error("QUOTES_DB_PATH environment variable not set")
        st.stop()
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_collection("quotes")
    model = SentenceTransformer("all-MiniLM-L6-v2")
    return collection, model

collection, model = load_db()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/app.py*
The search interface provides a text input and result count slider:
```python
# Search interface
query = st.text_input("Enter your search query:", placeholder="e.g., love, wisdom, success")
top_k = st.slider("Number of results:", min_value=1, max_value=20, value=5)

col1, col2 = st.columns([1, 1])
with col1:
    search_button = st.button("Search", type="primary", use_container_width=True)
with col2:
    random_button = st.button("Random Quote", use_container_width=True)

st.divider()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/app.py*
When the user searches, the app encodes the query and finds similar quotes:
```python
if search_button and query:
    # Encode query and search
    query_embedding = model.encode([query])[0].tolist()
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
    )
    if results["documents"] and results["documents"][0]:
        for i, (doc, metadata, distance) in enumerate(
            zip(results["documents"][0], results["metadatas"][0], results["distances"][0])
        ):
            similarity = 1 - distance  # Convert distance to similarity
            st.markdown(f'**{i+1}.** "{doc}"')
            st.caption(f"— {metadata['author']} | Similarity: {similarity:.2%}")
            st.write("")
    else:
        st.info("No results found.")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/app.py*
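Because the collection was created with `"hnsw:space": "cosine"`, ChromaDB returns cosine *distances*, so `1 - distance` recovers a similarity score for display. For example:

```python
# Cosine distances as returned by a ChromaDB query (illustrative values)
distances = [0.12, 0.35, 0.80]
similarities = [1 - d for d in distances]
labels = [f"{s:.2%}" for s in similarities]  # formatted like the app's captions
```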
The app also includes a random quote feature:
```python
elif random_button:
    # Get a random quote from the collection
    all_data = collection.get(limit=100)
    if all_data["documents"]:
        idx = random.randint(0, len(all_data["documents"]) - 1)
        quote = all_data["documents"][idx]
        author = all_data["metadatas"][idx]["author"]
        st.markdown(f'**"{quote}"**')
        st.caption(f"— {author}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/app.py*
### Deploying the app
To deploy the quote search application:
```python
if __name__ == "__main__":
    flyte.init_from_config()

    # Deploy the quote search app
    print("Deploying quote search app...")
    deployment = flyte.serve(env)
    print(f"App deployed at: {deployment.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/basic-project/serve.py*
```bash
uv run serve.py
```
The app will be deployed and automatically connected to the embedding pipeline's
output through the `RunOutput` parameter.
## Key takeaways
1. **Two-stage RAG pattern**: Separate offline embedding creation from online serving
for better resource utilization and cost efficiency.
2. **Dir artifacts**: Use `Dir` to pass entire directories (like databases) between
tasks and to serving applications.
3. **RunOutput**: Connect applications to task outputs declaratively, enabling
automatic data flow between pipelines and apps.
4. **CPU-friendly embeddings**: The `all-MiniLM-L6-v2` model runs efficiently on CPU,
making this pattern accessible without GPU resources.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/advanced-project ===
# Advanced project: LLM reporting agent
> [!NOTE]
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
This example demonstrates a resilient agentic report generator that showcases
Flyte 2.0's advanced features for building production-grade AI workflows.
## What you'll build
A batch report generator that:
1. Processes multiple topics in parallel
2. Iteratively critiques and refines each report until it meets a quality threshold
3. Produces multiple output formats (Markdown, HTML, summary) for each report
4. Serves results through an interactive UI
## Concepts covered
| Feature | Description |
|---------|-------------|
| `ReusePolicy` | Keep containers warm for high-throughput batch processing |
| `@flyte.trace` | Checkpoint LLM calls for recovery and observability |
| `RetryStrategy` | Handle transient API failures gracefully |
| `flyte.group` | Organize parallel batches and iterations in the UI |
| `asyncio.gather` | Fan out to process multiple topics concurrently |
| Pydantic models | Structured LLM outputs |
| `AppEnvironment` | Deploy interactive Streamlit apps |
| `RunOutput` | Connect apps to pipeline outputs |
## Architecture
```mermaid
flowchart TD
    A[Topics List] --> B
    B["report_batch_pipeline driver_env"]
    subgraph B1 ["refine_all (parallel)"]
        direction LR
        R1["refine_report topic 1"]
        R2["refine_report topic 2"]
        R3["refine_report topic N"]
    end
    B --> B1
    subgraph B2 ["format_all (parallel)"]
        direction LR
        F1["format_outputs report 1"]
        F2["format_outputs report 2"]
        F3["format_outputs report N"]
    end
    B1 --> B2
    B2 --> C["Output: List of Dirs"]
```
Each `refine_report` task runs in a reusable container (`llm_env`) and performs
multiple LLM calls through traced functions:
```mermaid
flowchart TD
    A[Topic] --> B["generate_initial_draft @flyte.trace"]
    B --> C
    subgraph C ["refinement_loop"]
        direction TB
        D["critique_content @flyte.trace"] -->|score >= threshold| E[exit loop]
        D -->|score < threshold| F["revise_content @flyte.trace"]
        F --> D
    end
    C --> G[Refined Report]
```
## Prerequisites
- A Flyte account with an active project
- An OpenAI API key stored as a secret named `openai-api-key`
To create the secret:
```bash
flyte create secret openai-api-key
```
## Parts
1. **Resilient generation**: Set up reusable environments, traced LLM calls, and retry strategies
2. **Agentic refinement**: Build the iterative critique-and-revise loop
3. **Parallel outputs**: Generate multiple formats concurrently
4. **Serving app**: Deploy an interactive UI for report generation
## Key takeaways
1. **Reusable environments for batch processing**: `ReusePolicy` keeps containers warm,
enabling efficient processing of multiple topics without cold start overhead. With
5 topics × ~7 LLM calls each, the reusable pool handles ~35 calls efficiently.
2. **Checkpointed LLM calls**: `@flyte.trace` provides automatic checkpointing at the
function level, enabling recovery without re-running expensive API calls.
3. **Agentic patterns**: The generate-critique-revise loop demonstrates how to build
self-improving AI workflows with clear observability through `flyte.group`.
4. **Parallel fan-out**: `asyncio.gather` processes multiple topics concurrently,
maximizing throughput by running refinement tasks in parallel across the batch.
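The fan-out shape in takeaway 4 can be seen in plain `asyncio`, independent of Flyte (the function names here are illustrative stand-ins for the real tasks):

```python
import asyncio

async def refine_report(topic: str) -> str:
    # Stand-in for the real task: the traced LLM loop would run here
    await asyncio.sleep(0.01)
    return f"report on {topic}"

async def refine_all(topics: list[str]) -> list[str]:
    # Fan out: one concurrent refinement per topic; results return in input order
    return await asyncio.gather(*(refine_report(t) for t in topics))

reports = asyncio.run(refine_all(["ai", "climate", "energy"]))
```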
## Subpages
- **Resilient generation**
- **Agentic refinement**
- **Parallel outputs**
- **Serving app**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/advanced-project/resilient-generation ===
# Resilient generation
This section covers the foundational patterns for building resilient LLM-powered
tasks: reusable environments, traced function calls, and retry strategies.
## Two environments
This example uses two task environments with different characteristics:
1. **`llm_env`** (reusable): For tasks that make many LLM calls in a loop or
process batches in parallel. Container reuse avoids cold starts.
2. **`driver_env`** (standard): For orchestration tasks that fan out work to
other tasks but don't make LLM calls themselves.
### Reusable environment for LLM work
When processing a batch of topics, each topic goes through multiple LLM calls
(generate, critique, revise, repeat). With 5 topics × ~7 calls each, that's ~35
LLM calls. `ReusePolicy` keeps containers warm to handle this efficiently:
```python
# Reusable environment for tasks that make many LLM calls in a loop.
# The ReusePolicy keeps containers warm, reducing cold start latency for iterative work.
llm_env = flyte.TaskEnvironment(
    name="llm-worker",
    secrets=[] if MOCK_MODE else [flyte.Secret(key="openai-api-key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "unionai-reuse>=0.1.10",
        "openai>=1.0.0",
        "pydantic>=2.0.0",
    ),
    resources=flyte.Resources(cpu=1, memory="2Gi"),
    reusable=flyte.ReusePolicy(
        replicas=2,  # Keep 2 container instances ready
        concurrency=4,  # Allow 4 concurrent tasks per container
        scaledown_ttl=timedelta(minutes=5),  # Wait 5 min before scaling down
        idle_ttl=timedelta(minutes=30),  # Shut down after 30 min idle
    ),
    cache="auto",
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
### ReusePolicy parameters
| Parameter | Description |
|-----------|-------------|
| `replicas` | Number of container instances to keep ready (or `(min, max)` tuple) |
| `concurrency` | Maximum tasks per container at once |
| `scaledown_ttl` | Minimum wait before scaling down a replica |
| `idle_ttl` | Time after which idle containers shut down completely |
The configuration above keeps 2 containers ready, allows 4 concurrent tasks per
container, waits 5 minutes before scaling down, and shuts down after 30 minutes
of inactivity.
> [!NOTE]
> Both `scaledown_ttl` and `idle_ttl` must be at least 30 seconds.
### Standard environment for orchestration
The driver environment doesn't need container reuseβit just coordinates work.
The `depends_on` parameter declares that tasks in this environment call tasks
in `llm_env`, ensuring both environments are deployed together:
```python
# Standard environment for orchestration tasks that don't need container reuse.
# depends_on declares that this environment's tasks call tasks in llm_env.
driver_env = flyte.TaskEnvironment(
    name="driver",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "pydantic>=2.0.0",
    ),
    resources=flyte.Resources(cpu=1, memory="1Gi"),
    depends_on=[llm_env],
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
## Traced LLM calls
The `@flyte.trace` decorator provides automatic checkpointing at the function level.
When a traced function completes successfully, its result is cached. If the task
fails and restarts, previously completed traced calls return their cached results
instead of re-executing.
```python
@flyte.trace
async def call_llm(prompt: str, system: str, json_mode: bool = False) -> str:
    """
    Make an LLM call with automatic checkpointing.

    The @flyte.trace decorator provides:
    - Automatic caching of results for identical inputs
    - Recovery from failures without re-running successful calls
    - Full observability in the Flyte UI

    Args:
        prompt: The user prompt to send
        system: The system prompt defining the LLM's role
        json_mode: Whether to request JSON output

    Returns:
        The LLM's response text
    """
    # Use mock responses for testing without API keys
    if MOCK_MODE:
        import asyncio

        await asyncio.sleep(0.5)  # Simulate API latency
        if "critique" in prompt.lower() or "critic" in system.lower():
            # Return good score if draft has been revised (contains revision marker)
            if "[REVISED]" in prompt:
                return MOCK_CRITIQUE_GOOD
            return MOCK_CRITIQUE_NEEDS_WORK
        elif "summary" in system.lower():
            return MOCK_SUMMARY
        elif "revis" in system.lower():
            # Return revised version with marker
            return MOCK_REPORT.replace("## Introduction", "[REVISED]\n\n## Introduction")
        else:
            return MOCK_REPORT

    from openai import AsyncOpenAI

    client = AsyncOpenAI()
    kwargs = {
        "model": "gpt-4o-mini",
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": 2000,
    }
    if json_mode:
        kwargs["response_format"] = {"type": "json_object"}

    response = await client.chat.completions.create(**kwargs)
    return response.choices[0].message.content
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
### Benefits of tracing
1. **Cost savings**: Failed tasks don't re-run expensive API calls that already succeeded
2. **Faster recovery**: Resuming from checkpoints skips completed work
3. **Observability**: Each traced call appears in the Flyte UI with timing data
### When to use @flyte.trace
Use `@flyte.trace` for:
- LLM API calls (OpenAI, Anthropic, etc.)
- External API requests
- Any expensive operation you don't want to repeat on retry
Don't use `@flyte.trace` for:
- Simple computations (overhead outweighs benefit)
- Operations with side effects that shouldn't be skipped
## Traced helper functions
The LLM-calling functions are decorated with `@flyte.trace` rather than being
separate tasks. This keeps the architecture simple while still providing
checkpointing:
```python
@flyte.trace
async def generate_initial_draft(topic: str) -> str:
"""
Generate the initial report draft.
The @flyte.trace decorator provides checkpointing - if the task fails
after this completes, it won't re-run on retry.
Args:
topic: The topic to write about
Returns:
The initial draft in markdown format
"""
print(f"Generating initial draft for topic: {topic}")
prompt = f"Write a comprehensive report on the following topic:\n\n{topic}"
draft = await call_llm(prompt, GENERATOR_SYSTEM_PROMPT)
print(f"Generated initial draft ({len(draft)} characters)")
return draft
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
These traced functions run inside the `refine_report` task. If the task fails
and retries, completed traced calls return cached results instead of re-executing.
## Retry strategies
The task that orchestrates the LLM calls uses `retries` to handle transient failures:
```python
@llm_env.task(retries=3)
async def refine_report(topic: str, ...) -> str:
# Traced functions are called here
draft = await generate_initial_draft(topic)
...
```
### Configuring retries
You can specify retries as a simple integer:
```python
@llm_env.task(retries=3)
async def my_task():
...
```
Or use `RetryStrategy` for more control:
```python
@llm_env.task(retries=flyte.RetryStrategy(count=3))
async def my_task():
...
```
### Combining tracing with retries
When you combine `@flyte.trace` with task-level retries, you get the best of both:
1. Task fails after completing some traced calls
2. Flyte retries the task
3. Previously completed traced calls return cached results
4. Only the failed operation (and subsequent ones) re-execute
This pattern is essential for multi-step LLM workflows where you don't want to
re-run the entire chain when a single call fails.
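The interaction between checkpointing and retries can be sketched in plain Python. This is a simplified, framework-free model of the pattern (not Flyte's actual implementation): a decorator caches completed calls in a dictionary, so when the enclosing "task" retries, only the failed step re-executes.

```python
import asyncio

# Framework-free sketch of the trace-plus-retry pattern (an illustration,
# not Flyte's implementation): completed calls are checkpointed, so a retry
# of the enclosing task skips work that already succeeded.
_checkpoints: dict = {}
calls = {"a": 0, "b": 0}

def trace(fn):
    async def wrapper(*args):
        key = (fn.__name__, args)
        if key in _checkpoints:  # completed on a previous attempt: reuse result
            return _checkpoints[key]
        result = await fn(*args)
        _checkpoints[key] = result  # checkpoint the successful call
        return result
    return wrapper

@trace
async def step_a(x: str) -> str:
    calls["a"] += 1
    return f"A({x})"

@trace
async def step_b(x: str) -> str:
    calls["b"] += 1
    if calls["b"] == 1:
        raise RuntimeError("transient failure")  # first attempt fails
    return f"B({x})"

async def task_with_retries(x: str, retries: int = 3) -> str:
    for _ in range(retries):
        try:
            a = await step_a(x)     # cached after the first attempt
            return await step_b(a)  # only this re-executes on retry
        except RuntimeError:
            continue
    raise RuntimeError("retries exhausted")

result = asyncio.run(task_with_retries("draft"))
print(result, calls)  # B(A(draft)) {'a': 1, 'b': 2}
```

Note that `step_a` executed once even though the task ran twice: the checkpoint, not the retry policy, is what prevents the duplicate call.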
## Structured prompts
The example uses a separate `prompts.py` module for system prompts and Pydantic models:
```python
GENERATOR_SYSTEM_PROMPT = """You are an expert report writer. Generate a well-structured,
informative report on the given topic. The report should include:
1. An engaging introduction that sets context
2. Clear sections with descriptive headings
3. Specific facts, examples, or data points where relevant
4. A conclusion that summarizes key takeaways
Write in a professional but accessible tone. Use markdown formatting for structure.
Aim for approximately 500-800 words."""
CRITIC_SYSTEM_PROMPT = """You are a demanding but fair editor reviewing a report draft.
Evaluate the report on these criteria:
- Clarity: Is the writing clear and easy to follow?
- Structure: Is it well-organized with logical flow?
- Depth: Does it provide sufficient detail and insight?
- Accuracy: Are claims supported and reasonable?
- Engagement: Is it interesting to read?
Provide your response as JSON matching this schema:
{
"score": <1-10 integer>,
"strengths": ["strength 1", "strength 2", ...],
"improvements": ["improvement 1", "improvement 2", ...],
"summary": "brief overall assessment"
}
Be specific in your feedback. A score of 8+ means the report is ready for publication."""
REVISER_SYSTEM_PROMPT = """You are an expert editor revising a report based on feedback.
Your task is to improve the report by addressing the specific improvements requested
while preserving its strengths.
Guidelines:
- Address each improvement point specifically
- Maintain the original voice and style
- Keep the same overall structure unless restructuring is requested
- Preserve any content that was praised as a strength
- Ensure the revised version is cohesive and flows well
Return only the revised report in markdown format, no preamble or explanation."""
SUMMARY_SYSTEM_PROMPT = """Create a concise executive summary (2-3 paragraphs) of the
following report. Capture the key points and main takeaways. Write in a professional
tone suitable for busy executives who need the essential information quickly."""
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/prompts.py*
### Pydantic models for structured output
LLM responses can be unpredictable. Using Pydantic models with JSON mode ensures
you get structured, validated data:
```python
class Critique(BaseModel):
"""Structured critique response from the LLM."""
score: int = Field(
ge=1,
le=10,
description="Quality score from 1-10, where 10 is publication-ready",
)
strengths: list[str] = Field(
description="List of strengths in the current draft",
)
improvements: list[str] = Field(
description="Specific improvements needed",
)
summary: str = Field(
description="Brief summary of the critique",
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/prompts.py*
The `Critique` model validates that:
- `score` is an integer between 1 and 10
- `strengths` and `improvements` are lists of strings
- All required fields are present
If the LLM returns malformed JSON, Pydantic raises a validation error, which
triggers a retry (if configured).
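The validation step can be illustrated without Pydantic. This framework-free sketch (using a plain dataclass instead of `pydantic.BaseModel`, so the constraint check is written by hand) shows the same parse-then-validate flow:

```python
import json
from dataclasses import dataclass

# Framework-free sketch of the validation Pydantic performs for `Critique`;
# the real example uses pydantic.BaseModel with Field(ge=1, le=10).
@dataclass
class Critique:
    score: int
    strengths: list
    improvements: list
    summary: str

    def __post_init__(self):
        # Hand-written equivalent of Field(ge=1, le=10)
        if not isinstance(self.score, int) or not (1 <= self.score <= 10):
            raise ValueError("score must be an integer between 1 and 10")

good = Critique(**json.loads(
    '{"score": 9, "strengths": ["clear"], "improvements": [], "summary": "solid"}'
))
print(good.score)  # 9

# An out-of-range score fails validation, just as it would in Pydantic:
try:
    Critique(score=0, strengths=[], improvements=[], summary="bad")
except ValueError as e:
    print(e)  # score must be an integer between 1 and 10
```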
## Next steps
With resilient generation in place, you're ready to build the
[agentic refinement loop](./agentic-refinement).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/advanced-project/agentic-refinement ===
# Agentic refinement
The core of this example is an agentic refinement loop: generate content, critique
it, revise based on feedback, and repeat until quality meets a threshold. This
pattern is fundamental to building self-improving AI systems.
## The agentic pattern
Traditional pipelines are linear: input → process → output. Agentic workflows
are iterative: they evaluate their own output and improve it through multiple
cycles.
```mermaid
flowchart TD
A[Generate] --> B[Critique]
B -->|score >= threshold| C[Done]
B -->|score < threshold| D[Revise]
D --> B
```
## Critique function
The critique function evaluates the current draft and returns structured feedback.
It's a traced function (not a separate task) that runs inside `refine_report`:
```python
@flyte.trace
async def critique_content(draft: str) -> Critique:
"""
Critique the current draft and return structured feedback.
Uses Pydantic models to parse the LLM's JSON response into
a typed object for reliable downstream processing.
Args:
draft: The current draft to critique
Returns:
Structured critique with score, strengths, and improvements
"""
print("Critiquing current draft...")
response = await call_llm(
f"Please critique the following report:\n\n{draft}",
CRITIC_SYSTEM_PROMPT,
json_mode=True,
)
# Parse the JSON response into our Pydantic model
critique_data = json.loads(response)
critique = Critique(**critique_data)
print(f"Critique score: {critique.score}/10")
print(f"Strengths: {len(critique.strengths)}, Improvements: {len(critique.improvements)}")
return critique
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
Key points:
- Uses `json_mode=True` to ensure the LLM returns valid JSON
- Parses the response into a Pydantic `Critique` model
- Returns a typed object for reliable downstream processing
- `@flyte.trace` provides checkpointing: if the task retries, completed critiques aren't re-run
## Revise function
The revise function takes the current draft and specific improvements to address:
```python
@flyte.trace
async def revise_content(draft: str, improvements: list[str]) -> str:
"""
Revise the draft based on critique feedback.
Args:
draft: The current draft to revise
improvements: List of specific improvements to address
Returns:
The revised draft
"""
print(f"Revising draft to address {len(improvements)} improvements...")
improvements_text = "\n".join(f"- {imp}" for imp in improvements)
prompt = f"""Please revise the following report to address these improvements:
IMPROVEMENTS NEEDED:
{improvements_text}
CURRENT DRAFT:
{draft}"""
revised = await call_llm(prompt, REVISER_SYSTEM_PROMPT)
print(f"Revision complete ({len(revised)} characters)")
return revised
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
The prompt includes:
1. The list of improvements from the critique
2. The current draft to revise
This focused approach helps the LLM make targeted changes rather than rewriting
from scratch.
## The refinement loop
The `refine_report` task orchestrates the iterative refinement. It runs in the
reusable `llm_env` because it makes multiple LLM calls through traced functions:
```python
@llm_env.task(retries=3)
async def refine_report(
topic: str,
max_iterations: int = 3,
quality_threshold: int = 8,
) -> str:
"""
Iteratively refine a report until it meets the quality threshold.
This task runs in a reusable container because it makes multiple LLM calls
in a loop. The traced helper functions provide checkpointing, so if the
task fails mid-loop, completed LLM calls won't be re-run on retry.
Args:
topic: The topic to write about
max_iterations: Maximum refinement cycles (default: 3)
quality_threshold: Minimum score to accept (default: 8)
Returns:
The final refined report
"""
# Generate initial draft
draft = await generate_initial_draft(topic)
# Iterative refinement loop
for i in range(max_iterations):
with flyte.group(f"refinement_{i + 1}"):
# Get critique
critique = await critique_content(draft)
# Check if we've met the quality threshold
if critique.score >= quality_threshold:
print(f"Quality threshold met at iteration {i + 1}!")
print(f"Final score: {critique.score}/10")
break
# Revise based on feedback
print(f"Score {critique.score} < {quality_threshold}, revising...")
draft = await revise_content(draft, critique.improvements)
else:
print(f"Reached max iterations ({max_iterations})")
return draft
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
### How it works
1. **Generate initial draft**: Creates the first version of the report
2. **Enter refinement loop**: Iterates up to `max_iterations` times
3. **Critique**: Evaluates the current draft and assigns a score
4. **Check threshold**: If score meets `quality_threshold`, exit early
5. **Revise**: If below threshold, revise based on improvements
6. **Repeat**: Continue until threshold met or iterations exhausted
All the LLM calls (generate, critique, revise) are traced functions inside this
single task. This keeps the task graph simple while the reusable container handles
the actual LLM work efficiently.
### Early exit
The `if critique.score >= quality_threshold: break` pattern enables early exit
when quality is sufficient. This saves compute costs and time: no need to run
all iterations if the first draft is already good.
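The control flow of the loop can be seen in isolation with stubbed model calls. In this sketch, the hard-coded scores `[6, 7, 9]` are an assumption standing in for real critique scores:

```python
# Minimal skeleton of the generate-critique-revise loop with stubbed model
# calls (the scores 6, 7, 9 stand in for real LLM critiques).
def refine(topic: str, max_iterations: int = 5, quality_threshold: int = 8):
    draft = f"draft about {topic}"
    stub_scores = iter([6, 7, 9])
    history = []
    for i in range(max_iterations):
        score = next(stub_scores)       # critique the current draft
        history.append(score)
        if score >= quality_threshold:  # early exit: quality is sufficient
            break
        draft = f"{draft} [revised {i + 1}]"  # revise based on feedback
    return draft, history

draft, history = refine("edge computing")
print(history)  # [6, 7, 9]: two revisions, then early exit on the third critique
```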
## Grouping iterations with flyte.group
Each refinement iteration is wrapped in `flyte.group`:
```python
for i in range(max_iterations):
with flyte.group(f"refinement_{i + 1}"):
critique = await critique_content(draft)
# ...
```
### Why use flyte.group?
Groups provide hierarchical organization in the Flyte UI. Since critique and
revise are traced functions (not separate tasks), groups help organize them:
```
refine_report
├── generate_initial_draft (traced)
├── refinement_1
│   ├── critique_content (traced)
│   └── revise_content (traced)
├── refinement_2
│   ├── critique_content (traced)
│   └── revise_content (traced)
└── [returns refined report]
```
Benefits:
- **Clarity**: See exactly how many iterations occurred
- **Debugging**: Quickly find which iteration had issues
- **Observability**: Track time spent in each refinement cycle
### Group context
Groups are implemented as context managers. All traced calls and nested groups
within the `with flyte.group(...)` block are associated with that group.
## Configuring the loop
The refinement loop accepts parameters to tune its behavior:
| Parameter | Default | Description |
|-----------|---------|-------------|
| `max_iterations` | 3 | Upper bound on refinement cycles |
| `quality_threshold` | 8 | Minimum score (1-10) to accept |
### Choosing thresholds
- **Higher threshold** (9-10): More refinement cycles, higher quality, more API costs
- **Lower threshold** (6-7): Faster completion, may accept lower quality
- **More iterations**: Safety net for difficult topics
- **Fewer iterations**: Cost control, faster turnaround
A good starting point is `quality_threshold=8` with `max_iterations=3`. Adjust
based on your quality requirements and budget.
## Best practices for agentic loops
1. **Always set max iterations**: Prevent infinite loops if the quality threshold
is never reached.
2. **Use structured critiques**: Pydantic models ensure you can reliably extract
the score and improvements from LLM responses.
3. **Log iteration progress**: Print statements help debug when reviewing logs:
```python
print(f"Iteration {i + 1}: score={critique.score}")
```
4. **Consider diminishing returns**: After 3-4 iterations, improvements often
become marginal. Set `max_iterations` accordingly.
5. **Use groups for observability**: `flyte.group` makes the iterative nature
visible in the UI, essential for debugging and monitoring.
## Next steps
With the agentic refinement loop complete, learn how to
[generate multiple outputs in parallel](./parallel-outputs).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/advanced-project/parallel-outputs ===
# Parallel outputs
After refining the report, the pipeline generates multiple output formats in
parallel. This demonstrates how to use `asyncio.gather` for concurrent execution
within a task.
## The formatting functions
The pipeline generates three outputs: markdown, HTML, and an executive summary.
Only `generate_summary` uses `@flyte.trace` because it makes an LLM call.
The markdown and HTML functions are simple, deterministic transformations that
don't benefit from checkpointing:
```python
async def format_as_markdown(content: str) -> str:
    """Format the report as clean markdown."""
    # Content is already markdown, but we could add TOC, metadata, etc.
    return f"""---
title: Generated Report
date: {__import__('datetime').datetime.now().isoformat()}
---

{content}
"""


async def format_as_html(content: str) -> str:
    """Convert the report to HTML."""
    # Simple markdown to HTML conversion
    import re

    html = content
    # Convert headers
    html = re.sub(r"^### (.+)$", r"<h3>\1</h3>", html, flags=re.MULTILINE)
    html = re.sub(r"^## (.+)$", r"<h2>\1</h2>", html, flags=re.MULTILINE)
    html = re.sub(r"^# (.+)$", r"<h1>\1</h1>", html, flags=re.MULTILINE)
    # Convert bold/italic
    html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", html)
    html = re.sub(r"\*(.+?)\*", r"<em>\1</em>", html)
    # Convert paragraphs
    html = re.sub(r"\n\n", "</p>\n<p>", html)
    return f"""<html>
<head><title>Generated Report</title></head>
<body>
<p>{html}</p>
</body>
</html>"""


@flyte.trace
async def generate_summary(content: str) -> str:
    """Generate an executive summary of the report."""
    return await call_llm(content, SUMMARY_SYSTEM_PROMPT)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
### When to trace and when not to
Use `@flyte.trace` for operations that are expensive, non-deterministic, or
call external APIs (like `generate_summary`). Skip it for cheap, deterministic
transformations (like `format_as_markdown` and `format_as_html`) where
re-running on retry is trivial.
## Parallel execution with asyncio.gather
The `format_outputs` task runs all formatters concurrently:
```python
@llm_env.task
async def format_outputs(content: str) -> Dir:
"""
Generate multiple output formats in parallel.
Uses asyncio.gather to run all formatting operations concurrently,
maximizing efficiency when each operation is I/O-bound.
Args:
content: The final report content
Returns:
Directory containing all formatted outputs
"""
print("Generating output formats in parallel...")
with flyte.group("formatting"):
# Run all formatting operations in parallel
markdown, html, summary = await asyncio.gather(
format_as_markdown(content),
format_as_html(content),
generate_summary(content),
)
# Write outputs to a directory
output_dir = tempfile.mkdtemp()
with open(os.path.join(output_dir, "report.md"), "w") as f:
f.write(markdown)
with open(os.path.join(output_dir, "report.html"), "w") as f:
f.write(html)
with open(os.path.join(output_dir, "summary.txt"), "w") as f:
f.write(summary)
print(f"Created outputs in {output_dir}")
return await Dir.from_local(output_dir)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
### How asyncio.gather works
`asyncio.gather` takes multiple coroutines and runs them concurrently:
```python
markdown, html, summary = await asyncio.gather(
format_as_markdown(content), # Starts immediately
format_as_html(content), # Starts immediately
generate_summary(content), # Starts immediately
)
# All three run concurrently, results returned in order
```
Without `gather`, these would run sequentially:
```python
# Sequential (slower)
markdown = await format_as_markdown(content) # Wait for completion
html = await format_as_html(content) # Then start this
summary = await generate_summary(content) # Then start this
```
### When to use asyncio.gather
Use `asyncio.gather` when:
- Operations are independent (don't depend on each other's results)
- Operations are I/O-bound (API calls, file operations)
- You want to minimize total execution time
Don't use `asyncio.gather` when:
- Operations depend on each other
- Operations are CPU-bound (use process pools instead)
- Order of execution matters for side effects
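The concurrency win is easy to verify with a timing sketch. Three fake I/O operations of 0.1 s each complete in roughly 0.1 s total under `asyncio.gather`, not 0.3 s:

```python
import asyncio
import time

# Timing sketch: three 0.1 s I/O-bound operations run concurrently under
# asyncio.gather, so total wall time is ~0.1 s rather than ~0.3 s.
async def fake_io(name: str, delay: float) -> str:
    await asyncio.sleep(delay)  # stands in for an API call or file write
    return name

async def main():
    start = time.perf_counter()
    results = await asyncio.gather(
        fake_io("markdown", 0.1),
        fake_io("html", 0.1),
        fake_io("summary", 0.1),
    )
    return results, time.perf_counter() - start

results, elapsed = asyncio.run(main())
print(results)         # ['markdown', 'html', 'summary'] (order matches the arguments)
print(elapsed < 0.25)  # True: concurrent, not sequential
```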
## Grouping parallel operations
The parallel formatting is wrapped in a group for UI clarity:
```python
with flyte.group("formatting"):
markdown, html, summary = await asyncio.gather(...)
```
In the Flyte UI, the traced call within the group is visible:
```
format_outputs
└── formatting
    ├── format_as_markdown
    ├── format_as_html
    └── generate_summary (traced)
```
## Collecting outputs in a directory
The formatted outputs are written to a temporary directory and returned as a
`Dir` artifact:
```python
output_dir = tempfile.mkdtemp()
with open(os.path.join(output_dir, "report.md"), "w") as f:
f.write(markdown)
with open(os.path.join(output_dir, "report.html"), "w") as f:
f.write(html)
with open(os.path.join(output_dir, "summary.txt"), "w") as f:
f.write(summary)
return await Dir.from_local(output_dir)
```
The `Dir.from_local()` call uploads the directory to Flyte's
artifact storage, making it available to downstream tasks or applications.
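The collect-into-a-directory step is plain Python and works the same without the Flyte-specific upload. A minimal sketch (file names match the example; the bodies are placeholders):

```python
import os
import tempfile

# Same collect-then-hand-off pattern, minus Dir.from_local (Flyte-specific):
# write each artifact into one temp directory, then pass the directory along.
output_dir = tempfile.mkdtemp()
artifacts = {
    "report.md": "# Report",
    "report.html": "<h1>Report</h1>",
    "summary.txt": "Short summary",
}
for name, body in artifacts.items():
    with open(os.path.join(output_dir, name), "w") as f:
        f.write(body)

print(sorted(os.listdir(output_dir)))  # ['report.html', 'report.md', 'summary.txt']
```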
## The batch pipeline
The batch pipeline processes multiple topics in parallel, demonstrating where
`ReusePolicy` truly shines:
```python
@driver_env.task
async def report_batch_pipeline(
topics: list[str],
max_iterations: int = 3,
quality_threshold: int = 8,
) -> list[Dir]:
"""
Generate reports for multiple topics in parallel.
This is where ReusePolicy shines: with N topics, each going through
up to max_iterations refinement cycles, the reusable container pool
handles potentially N × 7 LLM calls efficiently without cold starts.
Args:
topics: List of topics to write about
max_iterations: Maximum refinement cycles per topic
quality_threshold: Minimum quality score to accept
Returns:
List of directories, each containing a report's formatted outputs
"""
print(f"Starting batch pipeline for {len(topics)} topics...")
# Fan out: refine all reports in parallel
# Each refine_report makes 2-7 LLM calls, all hitting the reusable pool
with flyte.group("refine_all"):
reports = await asyncio.gather(*[
refine_report(topic, max_iterations, quality_threshold)
for topic in topics
])
print(f"All {len(reports)} reports refined, formatting outputs...")
# Fan out: format all reports in parallel
with flyte.group("format_all"):
outputs = await asyncio.gather(*[
format_outputs(report)
for report in reports
])
print(f"Batch pipeline complete! Generated {len(outputs)} reports.")
return outputs
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
### Pipeline flow
1. **Fan out refine_all**: Process all topics in parallel using `asyncio.gather`
2. **Fan out format_all**: Format all reports in parallel
3. **Return list of Dirs**: Each directory contains one report's outputs
With 5 topics, each making ~7 LLM calls, the reusable container pool handles
~35 LLM calls efficiently without cold starts.
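The two-stage fan-out has a simple shape when the real tasks are replaced with stubs (the string-building functions below are assumptions standing in for `refine_report` and `format_outputs`):

```python
import asyncio

# Shape of the two-stage fan-out: stage 1 refines all topics in parallel,
# stage 2 formats all reports in parallel. Stubs stand in for the real tasks.
async def refine(topic: str) -> str:
    await asyncio.sleep(0)
    return f"report({topic})"

async def fmt(report: str) -> str:
    await asyncio.sleep(0)
    return f"dir[{report}]"

async def batch(topics: list) -> list:
    reports = await asyncio.gather(*[refine(t) for t in topics])  # stage 1: fan out
    return await asyncio.gather(*[fmt(r) for r in reports])       # stage 2: fan out

outputs = asyncio.run(batch(["a", "b", "c"]))
print(outputs)  # ['dir[report(a)]', 'dir[report(b)]', 'dir[report(c)]']
```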
## Running the pipeline
To run the batch pipeline:
```python
if __name__ == "__main__":
flyte.init_from_config()
# Multiple topics to generate reports for
topics = [
"The Impact of Large Language Models on Software Development",
"Edge Computing: Bringing AI to IoT Devices",
"Quantum Computing: Current State and Near-Term Applications",
"The Rise of Rust in Systems Programming",
"WebAssembly: The Future of Browser-Based Applications",
]
print(f"Submitting batch run for {len(topics)} topics...")
import sys
sys.stdout.flush()
# Run the batch pipeline - this will generate all reports in parallel,
    # with the reusable container pool handling 5 topics × ~7 LLM calls each
run = flyte.run(
report_batch_pipeline,
topics=topics,
max_iterations=3,
quality_threshold=8,
)
print(f"Batch report generation run URL: {run.url}")
sys.stdout.flush()
print("Waiting for pipeline to complete (Ctrl+C to skip)...")
try:
run.wait()
print(f"Pipeline complete! Outputs: {run.outputs()}")
except KeyboardInterrupt:
print(f"\nSkipped waiting. Check status at: {run.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/generate.py*
Run it with:
```bash
uv run generate.py
```
The pipeline will:
1. Process all topics in parallel (each with iterative refinement)
2. Format all reports in parallel
3. Return a list of directories, each containing a report's outputs
## Cost optimization tips
### 1. Choose the right model
The example uses `gpt-4o-mini` for cost efficiency. For higher quality (at higher
cost), you could use `gpt-4o` or `gpt-4-turbo`:
```python
response = await client.chat.completions.create(
model="gpt-4o", # More capable, more expensive
...
)
```
### 2. Tune iteration parameters
Fewer iterations mean fewer API calls:
```python
run = flyte.run(
report_batch_pipeline,
topics=["Topic A", "Topic B"],
max_iterations=2, # Limit iterations
quality_threshold=7, # Accept slightly lower quality
)
```
### 3. Use caching effectively
The `cache="auto"` setting on the environment caches task outputs. Running the
same pipeline with the same inputs returns cached results instantly:
```python
llm_env = flyte.TaskEnvironment(
...
cache="auto", # Cache task outputs
)
```
### 4. Scale the batch
The batch pipeline already processes topics in parallel. To handle larger batches,
adjust the `ReusePolicy`:
```python
reusable=flyte.ReusePolicy(
replicas=4, # More containers for larger batches
concurrency=4, # Tasks per container
...
)
```
With 4 replicas × 4 concurrency = 16 slots, you can process 16 topics' refinement
tasks concurrently.
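A semaphore gives a toy model of this capacity arithmetic (an analogy, not Flyte's scheduler): `replicas × concurrency` bounds how many tasks run at once, and extras queue until a slot frees.

```python
import asyncio

# Toy capacity model: replicas x concurrency bounds simultaneous tasks;
# the rest wait for a slot. An analogy for ReusePolicy, not its implementation.
async def run_with_slots(n_tasks: int, replicas: int, concurrency: int) -> int:
    slots = asyncio.Semaphore(replicas * concurrency)
    in_flight = 0
    peak = 0

    async def task():
        nonlocal in_flight, peak
        async with slots:  # acquire one of replicas * concurrency slots
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.01)  # stands in for the actual work
            in_flight -= 1

    await asyncio.gather(*[task() for _ in range(n_tasks)])
    return peak

peak = asyncio.run(run_with_slots(40, replicas=4, concurrency=4))
print(peak)  # 16: concurrency never exceeds replicas x concurrency
```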
## Next steps
Learn how to [deploy a serving app](./serving-app) that connects to the pipeline
outputs and provides an interactive UI for report generation.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/advanced-project/serving-app ===
# Serving app
The final piece is a serving application that displays generated reports and
provides an interactive interface. This demonstrates how to connect apps to
pipeline outputs using `RunOutput`.
## App environment configuration
The `AppEnvironment` defines how the Streamlit application runs and connects to
the batch report pipeline:
```python
# Define the app environment
env = AppEnvironment(
name="report-generator-app",
description="Interactive report generator with AI-powered refinement",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit>=1.41.0",
),
args=["streamlit", "run", "app.py", "--server.port", "8080"],
port=8080,
resources=flyte.Resources(cpu=1, memory="2Gi"),
parameters=[
# Connect to the batch pipeline output (list of report directories)
Parameter(
name="reports",
value=RunOutput(
task_name="driver.report_batch_pipeline",
type="directory",
),
download=True,
env_var="REPORTS_PATH",
),
],
include=["app.py"],
requires_auth=False,
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/serve.py*
### Key configuration
| Setting | Purpose |
|---------|---------|
| `args` | Command to run the Streamlit app |
| `port` | Port the app listens on |
| `parameters` | Inputs to the app, including pipeline connections |
| `include` | Additional files to bundle with the app |
### Connecting to pipeline output with RunOutput
The `RunOutput` parameter connects the app to the batch pipeline's output:
```python
Parameter(
name="reports",
value=RunOutput(
task_name="driver.report_batch_pipeline",
type="directory",
),
download=True,
env_var="REPORTS_PATH",
)
```
This configuration:
1. **Finds the latest run** of `report_batch_pipeline` in the `driver` environment
2. **Downloads the output** to local storage (`download=True`)
3. **Sets an environment variable** with the path (`REPORTS_PATH`)
The app can then scan this directory for all generated reports.
## The Streamlit application
The app loads and displays all generated reports from the batch pipeline:
```python
def load_report_from_dir(report_dir: str) -> dict | None:
"""Load a single report from a directory."""
if not os.path.isdir(report_dir):
return None
report = {"path": report_dir, "name": os.path.basename(report_dir)}
md_path = os.path.join(report_dir, "report.md")
if os.path.exists(md_path):
with open(md_path) as f:
report["markdown"] = f.read()
html_path = os.path.join(report_dir, "report.html")
if os.path.exists(html_path):
with open(html_path) as f:
report["html"] = f.read()
summary_path = os.path.join(report_dir, "summary.txt")
if os.path.exists(summary_path):
with open(summary_path) as f:
report["summary"] = f.read()
# Only return if we found at least markdown content
return report if "markdown" in report else None
def load_all_reports() -> list[dict]:
"""Load all reports from the batch pipeline output."""
reports_path = os.environ.get("REPORTS_PATH")
if not reports_path or not os.path.exists(reports_path):
return []
reports = []
# Check if this is a single report directory (has report.md directly)
if os.path.exists(os.path.join(reports_path, "report.md")):
report = load_report_from_dir(reports_path)
if report:
report["name"] = "Report"
reports.append(report)
else:
# Batch output: scan subdirectories for reports
for entry in sorted(os.listdir(reports_path)):
entry_path = os.path.join(reports_path, entry)
report = load_report_from_dir(entry_path)
if report:
reports.append(report)
return reports
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/app.py*
### Displaying multiple reports
The app provides a sidebar for selecting between reports when multiple are available:
```python
reports = load_all_reports()
if reports:
# Sidebar for report selection if multiple reports
if len(reports) > 1:
st.sidebar.header("Select Report")
report_names = [f"Report {i+1}: {r['name']}" for i, r in enumerate(reports)]
selected_idx = st.sidebar.selectbox(
"Choose a report to view:",
range(len(reports)),
format_func=lambda i: report_names[i],
)
selected_report = reports[selected_idx]
st.sidebar.markdown(f"**Viewing {len(reports)} reports**")
else:
selected_report = reports[0]
st.header(f"Generated Report: {selected_report['name']}")
# Summary section
if "summary" in selected_report:
with st.expander("Executive Summary", expanded=True):
st.write(selected_report["summary"])
# Tabbed view for different formats
tab_md, tab_html = st.tabs(["Markdown", "HTML Preview"])
with tab_md:
st.markdown(selected_report.get("markdown", ""))
with tab_html:
if "html" in selected_report:
st.components.v1.html(selected_report["html"], height=600, scrolling=True)
# Download options
st.subheader("Download")
col1, col2, col3 = st.columns(3)
with col1:
if "markdown" in selected_report:
st.download_button(
label="Download Markdown",
data=selected_report["markdown"],
file_name="report.md",
mime="text/markdown",
)
with col2:
if "html" in selected_report:
st.download_button(
label="Download HTML",
data=selected_report["html"],
file_name="report.html",
mime="text/html",
)
with col3:
if "summary" in selected_report:
st.download_button(
label="Download Summary",
data=selected_report["summary"],
file_name="summary.txt",
mime="text/plain",
)
else:
st.info("No reports generated yet. Run the batch pipeline to create reports.")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/app.py*
Features:
- **Report selector**: Sidebar navigation when multiple reports exist
- **Executive summary**: Expandable section with key takeaways
- **Tabbed views**: Switch between Markdown and HTML preview
- **Download buttons**: Export in any format
### Generation instructions
The app includes instructions for generating new reports:
```python
st.divider()
st.header("Generate New Reports")
st.write("""
To generate reports, run the batch pipeline:
```bash
uv run generate.py
```
This generates reports for multiple topics in parallel, demonstrating
how ReusePolicy efficiently handles many concurrent LLM calls.
""")
# Show pipeline parameters info
with st.expander("Pipeline Parameters"):
st.markdown("""
**Available parameters:**
| Parameter | Default | Description |
|-----------|---------|-------------|
| `topics` | (required) | List of topics to write about |
| `max_iterations` | 3 | Maximum refinement cycles per topic |
| `quality_threshold` | 8 | Minimum score (1-10) to accept |
**Example:**
    ...
    """)
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/app.py*
## Deploying the app
The deployment entry point in `serve.py` calls `flyte.serve`:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    # Deploy the report generator app
    print("Deploying report generator app...")
    deployment = flyte.serve(env)
    print(f"App deployed at: {deployment.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/advanced-project/serve.py*
Deploy it with:
```bash
uv run serve.py
```
To generate reports and view them in the app:
1. **Run the pipeline**:
   ```bash
   uv run generate.py
   ```
2. **Deploy the app**:
   ```bash
   uv run serve.py
   ```
3. **Access the app** at the provided URL and browse all generated reports
The app automatically picks up the latest pipeline run, so you can generate
new batches and always see the most recent results.
## Automatic updates with RunOutput
The `RunOutput` connection is evaluated at app startup. Each time the app
restarts or redeploys, it fetches the latest batch pipeline output.
For real-time updates without redeployment, you could:
1. Poll for new runs using the Flyte API
2. Implement a webhook that triggers app refresh
3. Use a database to track run status
## Complete example structure
Here's the full project structure:
```
advanced-project/
├── generate.py   # Batch pipeline: traced LLM calls, refinement loop, formatting
├── prompts.py    # System prompts and Pydantic models
├── serve.py      # AppEnvironment definition and deployment
└── app.py        # Streamlit application
```
## Running the complete example
1. **Set up the secret**: make your OpenAI API key available to the pipeline
   (`call_llm` uses the `AsyncOpenAI` client, which reads `OPENAI_API_KEY`).
2. **Run the pipeline**:
   ```bash
   uv run generate.py
   ```
3. **Deploy the app**:
   ```bash
   uv run serve.py
   ```
4. **Open the app URL** and view your generated report
## Summary
This example demonstrated:
| Feature | What it does |
|---------|--------------|
| `ReusePolicy` | Keeps containers warm for high-throughput batch processing |
| `@flyte.trace` | Checkpoints LLM calls for recovery and observability |
| `RetryStrategy` | Handles transient API failures gracefully |
| `flyte.group` | Organizes parallel batches and iterations in the UI |
| `asyncio.gather` | Fans out to process multiple topics concurrently |
| Pydantic models | Structured LLM outputs |
| `AppEnvironment` | Deploys interactive Streamlit apps |
| `RunOutput` | Connects apps to pipeline outputs |
These patterns form the foundation for building production-grade AI workflows
that are resilient, observable, and cost-efficient at scale.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2 ===
# From Flyte 1 to 2
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
Flyte 2 represents a fundamental shift in how Flyte workflows are written and executed.
## Pure Python execution
Write workflows in pure Python, enabling a more natural development experience and removing the constraints of a
domain-specific language (DSL).
### Sync Python
```python
import flyte
env = flyte.TaskEnvironment("sync_example_env")
@env.task
def hello_world(name: str) -> str:
return f"Hello, {name}!"
@env.task
def main(name: str) -> str:
for i in range(10):
hello_world(name)
return "Done"
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main, name="World")
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/sync_example.py*
### Async Python
```python
import asyncio
import flyte
env = flyte.TaskEnvironment("async_example_env")
@env.task
async def hello_world(name: str) -> str:
return f"Hello, {name}!"
@env.task
async def main(name: str) -> str:
results = []
for i in range(10):
results.append(hello_world(name))
await asyncio.gather(*results)
return "Done"
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main, name="World")
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async_example.py*
As you can see in the hello world example, workflows can be constructed at runtime, allowing for more flexible and
adaptive behavior. Flyte 2 supports:
- Python's asynchronous programming model to express parallelism.
- Python's native error handling with `try-except`, including around task calls with overridden configurations, like resource requests.
- Predefined static workflows when compile-time safety is critical.
## Simplified API
The new API is more intuitive, with fewer abstractions to learn and a focus on simplicity.
| Use case | Flyte 1 | Flyte 2 |
| ----------------------------- | --------------------------- | --------------------------------------- |
| Environment management | `N/A` | `TaskEnvironment` |
| Perform basic computation | `@task` | `@env.task` |
| Combine tasks into a workflow | `@workflow` | `@env.task` |
| Create dynamic workflows | `@dynamic` | `@env.task` |
| Fanout parallelism | `flytekit.map` | Python `for` loop with `asyncio.gather` |
| Conditional execution | `flytekit.conditional` | Python `if-elif-else` |
| Catching workflow failures | `@workflow(on_failure=...)` | Python `try-except` |
There is no `@workflow` decorator. Instead, "workflows" are authored through a pattern of tasks calling tasks.
Tasks are defined within environments, which encapsulate the context and resources needed for execution.
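Because the right-hand column of the table is just ordinary Python, no special API is needed. The sketch below uses plain `asyncio` with no Flyte imports (in a real workflow each function would be an `@env.task`; the function names are illustrative) to show `if`/`else` and `try-except` standing in for `flytekit.conditional` and `on_failure`:

```python
import asyncio

async def score(x: int) -> int:
    # In Flyte 2 this would be an @env.task; plain asyncio is shown here.
    if x < 0:  # conditional execution: ordinary if/else
        raise ValueError("negative input")
    return x * 2

async def main(xs: list[int]) -> list[int]:
    results = []
    for x in xs:
        try:
            results.append(await score(x))  # catching failures: ordinary try-except
        except ValueError:
            results.append(0)  # fall back instead of failing the whole run
    return results

print(asyncio.run(main([1, -2, 3])))  # [2, 0, 6]
```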
## Fine-grained reproducibility and recoverability
As in Flyte 1, Flyte 2 supports caching at the task level (via `@env.task(cache=...)`), but it further enables recovery at the finer-grained, sub-task level through a feature called tracing (via `@flyte.trace`).
```python
import flyte
env = flyte.TaskEnvironment(name="trace_example_env")
@flyte.trace
async def call_llm(prompt: str) -> str:
return "Initial response from LLM"
@env.task
async def finalize_output(output: str) -> str:
return "Finalized output"
@env.task(cache=flyte.Cache(behavior="auto"))
async def main(prompt: str) -> str:
output = await call_llm(prompt)
output = await finalize_output(output)
return output
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main, prompt="Prompt to LLM")
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/trace.py*
Here `call_llm` runs in the same container as `main` and acts as an automated checkpoint with full observability in the UI.
If the task fails due to a system error (e.g., node preemption or infrastructure failure), Flyte can recover and replay from the
last successful trace rather than restarting from the beginning.
Note that tracing is distinct from caching: traces are recovered only after a system failure within a run,
whereas cached outputs are persisted for reuse across separate runs.
## Improved remote functionality
Flyte 2 provides full management of the workflow lifecycle through a standardized API, exposed through both the CLI and the Python SDK.
| Use case | CLI | Python SDK |
| ------------- | ------------------ | ------------------- |
| Run a task | `flyte run ...` | `flyte.run(...)` |
| Deploy a task | `flyte deploy ...` | `flyte.deploy(...)` |
You can also fetch and run remote (previously deployed) tasks within the course of a running workflow.
```python
import flyte
from flyte import remote
env_1 = flyte.TaskEnvironment(name="env_1")
env_2 = flyte.TaskEnvironment(name="env_2")
env_1.add_dependency(env_2)
@env_2.task
async def remote_task(x: str) -> str:
return "Remote task processed: " + x
@env_1.task
async def main() -> str:
remote_task_ref = remote.Task.get("env_2.remote_task", auto_version="latest")
r = await remote_task_ref(x="Hello")
    return "main called remote and received: " + r
if __name__ == "__main__":
flyte.init_from_config()
d = flyte.deploy(env_1)
print(d[0].summary_repr())
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/remote.py*
## Native Notebook support
Author and run workflows and fetch workflow metadata (I/O and logs) directly from Jupyter notebooks.

## Subpages
- **From Flyte 1 to 2 > Pure Python**
- **From Flyte 1 to 2 > Asynchronous model**
- **From Flyte 1 to 2 > Migration from Flyte 1 to Flyte 2**
- **From Flyte 1 to 2 > Considerations**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2/pure-python ===
# Pure Python
Flyte 2 introduces a new way of writing workflows that is based on pure Python, removing the constraints of a domain-specific language (DSL) and enabling full use of Python's capabilities.
## From `@workflow` DSL to pure Python
| Flyte 1 | Flyte 2 |
| --- | --- |
| `@workflow`-decorated functions are constrained to a subset of Python for defining a static directed acyclic graph (DAG) of tasks. | **No more `@workflow` decorator**: Everything is a `@env.task`, so your top-level "workflow" is simply a task that calls other tasks. |
| `@task`-decorated functions could leverage the full power of Python, but only within individual container executions. | `@env.task`s can call other `@env.task`s and be used to construct workflows with dynamic structures using loops, conditionals, try/except, and any Python construct anywhere. |
| Workflows were compiled into static DAGs at registration time, with tasks as the nodes and the DSL defining the structure. | Workflows are simply tasks that call other tasks. Compile-time safety will be available in the future as `compiled_task`. |
### Flyte 1
```python
import flytekit

image = flytekit.ImageSpec(
    name="hello-world-image",
    packages=["requests"],
)

@flytekit.task(container_image=image)
def mean(data: list[float]) -> float:
    return sum(data) / len(data)

@flytekit.workflow
def main(data: list[float]) -> float:
    output = mean(data)
    # ❌ performing trivial operations in a workflow is not allowed
    # output = output / 100
    # ❌ if/else is not allowed
    # if output < 0:
    #     raise ValueError("Output cannot be negative")
    return output
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/pure-python/flyte_1.py*
### Flyte 2
```python
import flyte

env = flyte.TaskEnvironment(
    "hello_world",
    image=flyte.Image.from_debian_base().with_pip_packages("requests"),
)

@env.task
def mean(data: list[float]) -> float:
    return sum(data) / len(data)

@env.task
def main(data: list[float]) -> float:
    output = mean(data)
    # ✅ performing trivial operations in a workflow is allowed
    output = output / 100
    # ✅ if/else is allowed
    if output < 0:
        raise ValueError("Output cannot be negative")
    return output
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/pure-python/flyte_2.py*
These fundamental changes bring several transformative benefits:
- **Flexibility**: Harness the complete Python language for workflow definition, including all control flow constructs previously forbidden in workflows.
- **Dynamic workflows**: Create workflows that adapt to runtime conditions, handle variable data structures, and make decisions based on intermediate results.
- **Natural error handling**: Use standard Python `try`/`except` patterns throughout your workflows, making them more robust and easier to debug.
- **Intuitive composability**: Build complex workflows by naturally composing Python functions, following familiar patterns that any Python developer understands.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2/async ===
# Asynchronous model
## Why we need an async model
The shift to an asynchronous model in Flyte 2 is driven by the need for more efficient and flexible workflow execution.
We believe, in particular, that with the rise of the agentic AI pattern, asynchronous programming has become an essential part of the AI/ML engineering and data science toolkit.
With Flyte 2, the entire framework is now written with async constructs, allowing for:
- Seamless overlapping of I/O and independent external operations.
- Composing multiple tasks and external tool invocations within the same Python process.
- Native support of streaming operations for data, observability and downstream invocations.
It is also a natural fit for expressing parallelism in workflows.
### Understanding concurrency vs. parallelism
**Concurrency** means running multiple tasks at once. This can be achieved by interleaving execution on a single thread (switching between tasks when one is waiting) or by true **parallelism**βexecuting tasks truly simultaneously across multiple cores or machines. Parallelism is a form of concurrency, but concurrency doesn't require parallelism.
### Python's async evolution
Python's asynchronous programming capabilities have evolved significantly:
- **The GIL challenge**: Python's Global Interpreter Lock (GIL) traditionally prevented true parallelism for CPU-bound tasks, limiting threading effectiveness to I/O-bound operations.
- **Traditional solutions**:
- `multiprocessing`: Created separate processes to sidestep the GIL, effective but resource-intensive
- `threading`: Useful for I/O-bound tasks where the GIL could be released during external operations
- **The async revolution**: The `asyncio` library introduced cooperative multitasking within a single thread, using an event loop to manage multiple tasks efficiently.
### Parallelism in Flyte 1 vs Flyte 2
| | Flyte 1 | Flyte 2 |
| --- | --- | --- |
| Parallelism | The workflow DSL automatically parallelized tasks that weren't dependent on each other. The `map` operator allowed running a task multiple times in parallel with different inputs. | Leverages Python's `asyncio` as the primary mechanism for expressing parallelism, but with a crucial difference: **the Flyte orchestrator acts as the event loop**, managing task execution across distributed infrastructure. |
### Core async concepts
- **`async def`**: Declares a function as a coroutine. When called, it returns a coroutine object managed by the event loop rather than executing immediately.
- **`await`**: Pauses coroutine execution and passes control back to the event loop.
In standard Python, this enables other tasks to run while waiting for I/O operations.
In Flyte 2, it signals where tasks can be executed in parallel.
- **`asyncio.gather`**: The primary tool for concurrent execution.
In standard Python, it schedules multiple awaitable objects to run concurrently within a single event loop.
In Flyte 2, it signals to the orchestrator that these tasks can be distributed across separate compute resources.
#### A practical example
Consider this pattern for parallel data processing:
```python
import asyncio
import flyte
env = flyte.TaskEnvironment("data_pipeline")
@env.task
async def process_chunk(chunk_id: int, data: str) -> str:
# This could be any computational work - CPU or I/O bound
await asyncio.sleep(1) # Simulating work
return f"Processed chunk {chunk_id}: {data}"
@env.task
async def parallel_pipeline(data_chunks: list[str]) -> list[str]:
# Create coroutines for all chunks
tasks = []
for i, chunk in enumerate(data_chunks):
tasks.append(process_chunk(i, chunk))
# Execute all chunks in parallel
results = await asyncio.gather(*tasks)
return results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py*
In standard Python, this would provide concurrency benefits primarily for I/O-bound operations.
In Flyte 2, the orchestrator schedules each `process_chunk` task on separate Kubernetes pods or configured plugins, achieving true parallelism for any type of work.
### True parallelism for all workloads
This is where Flyte 2's approach becomes revolutionary: **async syntax is not just for I/O-bound operations**.
The `async`/`await` syntax becomes a powerful way to declare your workflow's parallel structure for any type of computation.
When Flyte's orchestrator encounters `await asyncio.gather(...)`, it understands that these tasks are independent and can be executed simultaneously across different compute resources.
This means you achieve true parallelism for:
- **CPU-bound computations**: Heavy mathematical operations, model training, data transformations
- **I/O-bound operations**: Database queries, API calls, file operations
- **Mixed workloads**: Any combination of computational and I/O tasks
The Flyte platform handles the complex orchestration while you express parallelism using intuitive `async` syntax.
## Calling sync tasks from async tasks
### Synchronous task support
Since many existing codebases use synchronous functions, Flyte 2 provides synchronous support. Under the hood, Flyte automatically "asyncifies" synchronous functions, wrapping them to participate seamlessly in the async execution model.
You don't need to rewrite existing code; just use the `.aio()` method when calling sync tasks from async contexts:
```python
@env.task
def legacy_computation(x: int) -> int:
# Existing synchronous function works unchanged
return x * x + 2 * x + 1
@env.task
async def modern_workflow(numbers: list[int]) -> list[int]:
# Call sync tasks from async context using .aio()
tasks = []
for num in numbers:
tasks.append(legacy_computation.aio(num))
results = await asyncio.gather(*tasks)
return results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py*
### The `flyte.map` function: Familiar patterns
For scenarios that previously used Flyte 1's `map` operation, Flyte 2 provides `flyte.map` as a direct replacement.
The new `flyte.map` can be used either in synchronous or asynchronous contexts, allowing you to express parallelism without changing your existing patterns.
### Sync Map
```python
@env.task
def sync_map_example(n: int) -> list[str]:
# Synchronous version for easier migration
results = []
for result in flyte.map(process_item, range(n)):
if isinstance(result, Exception):
raise result
results.append(result)
return results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py*
### Async Map
```python
@env.task
async def async_map_example(n: int) -> list[str]:
# Async version using flyte.map - exact pattern from SDK examples
results = []
async for result in flyte.map.aio(process_item, range(n), return_exceptions=True):
if isinstance(result, Exception):
raise result
results.append(result)
return results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py*
The `flyte.map` function provides:
- **Dual interfaces**: `flyte.map.aio()` for async contexts, `flyte.map()` for sync contexts.
- **Built-in error handling**: `return_exceptions` parameter for graceful failure handling. This matches the `asyncio.gather` interface,
allowing you to decide how to handle errors.
If you are coming from Flyte 1, it allows you to replace `min_success_ratio` in a more flexible way.
- **Automatic UI grouping**: Creates logical groups for better workflow visualization.
- **Concurrency control**: Optional limits for resource management.
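Since `return_exceptions` follows the `asyncio.gather` contract, its semantics can be illustrated without Flyte at all. A minimal sketch in plain `asyncio` (the `work` function is made up for illustration):

```python
import asyncio

async def work(x: int) -> int:
    # Hypothetical work item; fails for one input to show error handling.
    if x == 2:
        raise RuntimeError("boom")
    return x * 10

async def main() -> list:
    # With return_exceptions=True, failures come back as values, in input
    # order, instead of cancelling the whole batch -- the same contract
    # flyte.map's return_exceptions parameter follows.
    results = await asyncio.gather(*(work(x) for x in range(4)), return_exceptions=True)
    return [None if isinstance(r, Exception) else r for r in results]

print(asyncio.run(main()))  # [0, 10, None, 30]
```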
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2/migration ===
# Migration from Flyte 1 to Flyte 2
> **Note**
>
> For comprehensive migration reference with detailed API mappings, parameter tables, and complete examples, see [Migration from Flyte 1](../../api-reference/migration/_index) in the Reference section.
> An LLM-optimized bundle of the full migration reference is available at [`section.md`](../../api-reference/migration/section.md).
You can migrate from Flyte 1 to Flyte 2 by following the steps below:
### 1. Move task configuration to a `TaskEnvironment` object
Instead of configuring the image, hardware resources, and so forth directly in the task decorator, you configure them in a `TaskEnvironment` object. For example:
```python
env = flyte.TaskEnvironment(name="my_task_env")
```
### 2. Replace workflow decorators
Then, you replace the `@workflow` and `@task` decorators with `@env.task` decorators.
### Flyte 1
Here's a simple hello world example with fanout.
```python
import flytekit
@flytekit.task
def hello_world(name: str) -> str:
return f"Hello, {name}!"
@flytekit.workflow
def main(names: list[str]) -> list[str]:
return flytekit.map(hello_world)(names)
```
### Flyte 2 Sync
Change all the decorators to `@env.task` and swap out `flytekit.map` with `flyte.map`.
Notice that `flyte.map` is a drop-in replacement for Python's built-in `map` function.
```diff
-@flytekit.task
+@env.task
def hello_world(name: str) -> str:
return f"Hello, {name}!"
-@flytekit.workflow
+@env.task
def main(names: list[str]) -> list[str]:
return flyte.map(hello_world, names)
```
> **Note**
>
> Note that the reason our task decorator uses `env` is simply because that is the variable to which we assigned the `TaskEnvironment` above.
### Flyte 2 Async
To take advantage of full concurrency (not just parallelism), use Python async
syntax and the `asyncio` standard library to implement fan-out.
```diff
+import asyncio
@env.task
-def hello_world(name: str) -> str:
+async def hello_world(name: str) -> str:
return f"Hello, {name}!"
@env.task
-def main(names: list[str]) -> list[str]:
+async def main(names: list[str]) -> list[str]:
- return flyte.map(hello_world, names)
+ return await asyncio.gather(*[hello_world(name) for name in names])
```
> **Note**
>
> To use Python async syntax, you need to:
> - Use `asyncio.gather()` or `flyte.map()` for parallel execution
> - Add `async`/`await` keywords where you want parallelism
> - Keep existing sync task functions unchanged
>
> Learn more about the benefits of async in the [Asynchronous Model](./async) guide.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2/considerations ===
# Considerations
Flyte 2 represents a substantial change from Flyte 1.
Each Python-based task action can act as its own engine: kicking off sub-actions, assembling their outputs, and passing them on to yet other sub-actions.
While this execution model comes with an enormous amount of flexibility, that flexibility does warrant some caveats to keep in mind when authoring your tasks.
## Non-deterministic behavior
When a task launches another task, a new Action ID is determined.
This ID is a hash of the inputs to the task, the task definition itself, along with some other information.
The fact that this ID is consistently hashed is important when it comes to things like recovery and replay.
For example, assume you have the following tasks:
```python
@env.task
async def t1():
    val = get_int_input()
    await t2(val=val)

@env.task
async def t2(val: int): ...
```
If you run `t1`, it launches the downstream `t2` task, and then the pod executing `t1` fails, Flyte will, when it restarts `t1`, automatically detect that `t2` is still running and simply reuse it.
If `t2` finishes in the interim, its results are simply used.
However, if you introduce non-determinism into the picture, then that guarantee is no longer there.
To give a contrived example:
```python
@env.task
async def t1():
    val = get_int_input()
    now = datetime.now()
    if now.second % 2 == 0:
        await t2(val=val)
    else:
        await t3(val=val)
```
Here, depending on what time it is, either `t2` or `t3` may end up running.
Unlike in the earlier scenario, if `t1` crashes unexpectedly and Flyte retries the execution, a different downstream task may be kicked off the second time.
### Dealing with non-determinism
As a developer, the best way to manage non-deterministic behavior (if it is unavoidable) is to be able to observe it and see exactly what is happening in your code. Flyte 2 provides precisely the tool needed to enable this: Traces.
With this feature you decorate the sub-task functions in your code with `@trace`, enabling checkpointing, reproducibility and recovery at a fine-grained level. See [Traces](../task-programming/traces) for more details.
## Type safety
In Flyte 1, the top-level workflow was defined by a Python-like DSL that was compiled into a static DAG composed of tasks, each of which was, internally, defined in real Python.
The system was able to guarantee type safety across task boundaries because the task definitions were static and the inputs and outputs were defined in a way that Flytekit could validate them.
In Flyte 2, the top-level workflow is defined by Python code that runs at runtime (unless using a compiled task).
This means that the system can no longer guarantee type safety at the workflow level.
Happily, the Python ecosystem has evolved considerably since Flyte 1, and Python type hints are now a standard way to define types.
Consequently, in Flyte 2, developers should use Python type hints and type checkers like `mypy` to ensure type safety at all levels, including the top-most task (i.e., the "workflow" level).
## No global state
A core principle of Flyte 2 (one that is also shared with Flyte 1) is that you should not try to maintain global state across your workflow: it will not be carried across task containers.
In a single-process Python program, global variables are available across functions.
In the distributed execution model of Flyte, each task runs in its own container, and each container is isolated from the others.
If there is some state that needs to be preserved, it must be passed explicitly through task inputs and outputs, or be reconstructable through repeated deterministic execution.
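A minimal illustration of the principle, with plain Python functions standing in for `@env.task`s:

```python
# Anti-pattern: module-level state. Each Flyte task runs in its own
# container, so a mutation made in one task is invisible to the next.
counter = 0

def bad_increment() -> None:
    global counter
    counter += 1  # lost when this task's container exits

# Pattern: thread state through task inputs and outputs instead, so it is
# part of the data Flyte passes (and can track) between containers.
def increment(count: int) -> int:
    return count + 1

count = 0
for _ in range(3):
    count = increment(count)  # state flows through I/O
print(count)  # 3
```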
## Driver pod requirements
Tasks don't have to kick off downstream tasks, of course; they may themselves represent a leaf-level atomic unit of compute.
However, when tasks do run other tasks, and especially when they assemble the outputs of those other tasks, the parent task becomes a driver pod of sorts.
In Flyte 1, this assembling of intermediate outputs was done by Flyte Propeller.
In Flyte 2, it is done by the parent task.
This means that the pod running your parent task must be appropriately sized, and should ideally not be CPU-bound; otherwise it will slow down the evaluation and kickoff of downstream tasks.
For example, consider this scenario:
```python
@env.task
async def t_main():
await t1()
local_cpu_intensive_function()
await t2()
```
The pod running `t_main` will be tied up by `local_cpu_intensive_function` between tasks `t1` and `t2`, delaying the kickoff of `t2`. Your parent tasks should ideally focus only on orchestration.
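One way to keep the parent orchestration-only is to give the CPU-bound step its own unit of work. In Flyte that would mean moving `local_cpu_intensive_function` into its own `@env.task`; the plain-`asyncio` sketch below shows the same shape using a worker thread (all names here are illustrative):

```python
import asyncio

def cpu_intensive(n: int) -> int:
    # Stand-in for local_cpu_intensive_function().
    return sum(i * i for i in range(n))

async def t1() -> str:
    return "t1 done"

async def t2() -> str:
    return "t2 done"

async def t_main() -> int:
    await t1()
    # Offload the heavy work so the parent stays free to orchestrate;
    # in Flyte, this step would be its own @env.task running on its own pod.
    heavy = asyncio.create_task(asyncio.to_thread(cpu_intensive, 10))
    await t2()  # t2 is kicked off without waiting for the CPU-bound work
    return await heavy

print(asyncio.run(t_main()))  # 285
```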
## OOM risk from materialized I/O
Something more nuanced to keep in mind is that if you're not using the soon-to-be-released ref mode, outputs are actually
materialized. That is, if you have the following scenario:
```python
@env.task
async def produce_1gb_list() -> list[float]: ...

@env.task
async def t1():
    list_floats = await produce_1gb_list()
    await t2(floats=list_floats)
```
The pod running `t1` needs to have memory to handle that 1 GB of floats. Those numbers will be materialized in that pod's memory.
This can lead to out of memory issues.
Note that `flyte.io.File`, `flyte.io.Dir` and `flyte.io.DataFrame` will not suffer from this because while those are materialized, they're only materialized as pointers to offloaded data, so their memory footprint is much lower.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration ===
# Configure tasks
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
As we saw in **Quickstart**, you can run any Python function as a task in Flyte just by decorating it with `@env.task`.
This allows you to run your Python code in a distributed manner, with each function running in its own container.
Flyte manages the spinning up of the containers, the execution of the code, and the passing of data between the tasks.
The simplest possible case is a `TaskEnvironment` with only a `name` parameter, and an `env.task` decorator, with no parameters:
```python
env = flyte.TaskEnvironment(name="my_env")

@env.task
async def my_task(name: str) -> str:
    return f"Hello {name}!"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py*
> [!NOTE]
> Notice how the `TaskEnvironment` is assigned to the variable `env` and then that variable is
> used in the `@env.task`. This is what connects the `TaskEnvironment` to the task definition.
>
> In the following we will often use `@env.task` generically to refer to the decorator,
> but it is important to remember that it is actually a decorator attached to a specific
> `TaskEnvironment` object, and the `env` part can be any variable name you like.
This will run your task in the default container environment with default settings.
But, of course, one of the key advantages of Flyte is the ability to control the software environment, hardware environment, and other execution parameters for each task, right in your Python code.
In this section we will explore the various configuration options available for tasks in Flyte.
## Task configuration levels
Task configuration is done at three levels. From most general to most specific, they are:
* The `TaskEnvironment` level: setting parameters when defining the `TaskEnvironment` object.
* The `@env.task` decorator level: Setting parameters in the `@env.task` decorator when defining a task function.
* The task invocation level: Using the `task.override()` method when invoking task execution.
Each level has its own set of parameters, and some parameters are shared across levels.
For shared parameters, the more specific level will override the more general one.
### Example
Here is an example of how these levels work together, showing each level with all available parameters:
```python
# Level 1: TaskEnvironment - Base configuration
env_2 = flyte.TaskEnvironment(
name="data_processing_env",
image=flyte.Image.from_debian_base(),
resources=flyte.Resources(cpu=1, memory="512Mi"),
env_vars={"MY_VAR": "value"},
# secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY"),
cache="disable",
# pod_template=my_pod_template,
# reusable=flyte.ReusePolicy(replicas=2, idle_ttl=300),
depends_on=[another_env],
description="Data processing task environment",
# plugin_config=my_plugin_config
)
# Level 2: Decorator - Override some environment settings
@env_2.task(
short_name="process",
# secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY_2"),
cache="auto",
# pod_template=my_pod_template,
report=True,
max_inline_io_bytes=100 * 1024,
retries=3,
timeout=60,
docs="This task processes data and generates a report."
)
async def process_data(data_path: str) -> str:
return f"Processed {data_path}"
@env_2.task
async def invoke_process_data() -> str:
result = await process_data.override(
resources=flyte.Resources(cpu=4, memory="2Gi"),
env_vars={"MY_VAR": "new_value"},
# secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY_3"),
cache="auto",
max_inline_io_bytes=100 * 1024,
retries=3,
timeout=60
)("input.csv")
return result
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py*
### Parameter interaction
Here is an overview of all task configuration parameters available at each level and how they interact:
| Parameter | `TaskEnvironment` | `@env.task` decorator | `override` on task invocation |
|-------------------------|--------------------|----------------------------|-------------------------------|
| `name` | ✅ Yes (required) | ❌ No | ❌ No |
| `short_name` | ❌ No | ✅ Yes | ✅ Yes |
| `image` | ✅ Yes | ❌ No | ❌ No |
| `resources` | ✅ Yes | ❌ No | ✅ Yes (if not `reusable`) |
| `env_vars` | ✅ Yes | ❌ No | ✅ Yes (if not `reusable`) |
| `secrets` | ✅ Yes | ❌ No | ✅ Yes (if not `reusable`) |
| `cache` | ✅ Yes | ✅ Yes | ✅ Yes |
| `pod_template` | ✅ Yes | ✅ Yes | ✅ Yes |
| `reusable` | ✅ Yes | ❌ No | ✅ Yes |
| `depends_on` | ✅ Yes | ❌ No | ❌ No |
| `description` | ✅ Yes | ❌ No | ❌ No |
| `plugin_config` | ✅ Yes | ❌ No | ❌ No |
| `report` | ❌ No | ✅ Yes | ❌ No |
| `max_inline_io_bytes` | ❌ No | ✅ Yes | ✅ Yes |
| `retries` | ❌ No | ✅ Yes | ✅ Yes |
| `timeout` | ❌ No | ✅ Yes | ✅ Yes |
| `triggers` | ❌ No | ✅ Yes | ❌ No |
| `links` | ❌ No | ✅ Yes | ✅ Yes |
| `interruptible` | ✅ Yes | ✅ Yes | ✅ Yes |
| `queue` | ✅ Yes | ✅ Yes | ✅ Yes |
| `docs` | ❌ No | ✅ Yes | ❌ No |
## Task configuration parameters
Each parameter is documented in detail on its dedicated page or in the API reference. For full type signatures and constraints, see the **Flyte SDK > Packages > flyte > TaskEnvironment**.
| Parameter | Details |
|-----------|---------|
| **name**, **short_name**, **description**, **docs** | **Configure tasks > Additional task settings** |
| **image** | **Configure tasks > Container images** • **Flyte SDK > Packages > flyte > Image** |
| **resources** | **Configure tasks > Resources** • **Flyte SDK > Packages > flyte > Resources** |
| **env_vars** | **Configure tasks > Additional task settings > Environment variables** |
| **secrets** | **Configure tasks > Secrets** • **Flyte SDK > Packages > flyte > Secret** |
| **cache** | **Configure tasks > Caching** • **Flyte SDK > Packages > flyte > Cache** |
| **pod_template** | **Configure tasks > Pod templates** • **Flyte SDK > Packages > flyte > PodTemplate** |
| **reusable** | **Configure tasks > Reusable containers** • **Flyte SDK > Packages > flyte > ReusePolicy** |
| **depends_on** | **Configure tasks > Multiple environments** |
| **plugin_config** | **Configure tasks > Task plugins** |
| **report** | **Configure tasks > Additional task settings > Naming and metadata > `report`** |
| **max_inline_io_bytes** | **Configure tasks > Additional task settings > Inline I/O threshold** |
| **retries**, **timeout** | **Configure tasks > Retries and timeouts** • **Flyte SDK > Packages > flyte > RetryStrategy**, **Flyte SDK > Packages > flyte > Timeout** API refs |
| **triggers** | **Configure tasks > Triggers** • **Flyte SDK > Packages > flyte > Trigger** |
| **links** | **Configure tasks > Additional task settings > Naming and metadata > `links`** |
| **interruptible**, **queue** | **Configure tasks > Interruptible tasks and queues** |
## Subpages
- **Configure tasks > Container images**
- **Configure tasks > Resources**
- **Configure tasks > Secrets**
- **Configure tasks > Caching**
- **Configure tasks > Reusable containers**
- **Configure tasks > Pod templates**
- **Configure tasks > Multiple environments**
- **Configure tasks > Retries and timeouts**
- **Configure tasks > Triggers**
- **Configure tasks > Interruptible tasks and queues**
- **Configure tasks > Task plugins**
- **Configure tasks > Additional task settings**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/container-images ===
# Container images
The `image` parameter of the [`TaskEnvironment`](../../api-reference/flyte-sdk/packages/flyte/taskenvironment) is used to specify a container image.
Every task defined using that `TaskEnvironment` will run in a container based on that image.
If a `TaskEnvironment` does not specify an `image`, it will use the default Flyte image ([`ghcr.io/flyteorg/flyte:py{python-version}-v{flyte_version}`](https://github.com/orgs/flyteorg/packages/container/package/flyte)).
## Specifying your own image directly
You can directly reference an image by URL in the `image` parameter, like this:
```python
env = flyte.TaskEnvironment(
    name="my_task_env",
    image="docker.io/myorg/myimage:mytag"
)
```
This works well if you have a pre-built image available in a public registry like Docker Hub or in a private registry that your Union/Flyte instance can access.
## Specifying your own image with the `flyte.Image` object
You can also construct an image programmatically using the `flyte.Image` object.
The `flyte.Image` object provides a fluent interface for building container images: start with a `from_*` base constructor, then customize with `with_*` methods. Each method returns a new immutable `Image`.
For a complete list of all available methods and their parameters, see the [`Image` API reference](../../api-reference/flyte-sdk/packages/flyte/image).
Here are some examples of the most common patterns for building images with `flyte.Image`.
## Example: Defining a custom image with `Image.from_debian_base`
The `Image.from_debian_base()` method provides the default Flyte image as the base.
This image is itself based on the official Python Docker image (specifically `python:{version}-slim-bookworm`) with the addition of the Flyte SDK pre-installed.
Starting there, you can layer additional features onto your image.
For example:
```python
import flyte
import numpy as np

# Define the task environment
env = flyte.TaskEnvironment(
    name="my_env",
    image=(
        flyte.Image.from_debian_base(
            name="my-image",
            python_version=(3, 13)
            # registry="registry.example.com/my-org"  # Only needed for local builds
        )
        .with_apt_packages("libopenblas-dev")
        .with_pip_packages("numpy")
        .with_env_vars({"OMP_NUM_THREADS": "4"})
    )
)


@env.task
def main(x_list: list[int]) -> float:
    arr = np.array(x_list)
    return float(np.mean(arr))


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, x_list=list(range(10)))
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/container-images/from_debian_base.py*
> [!NOTE]
> The `registry` parameter is only needed if you are building the image locally. It is not required when using the Union backend `ImageBuilder`.
> See **Configure tasks > Container images > Image building** for more details.
> [!NOTE]
> Images built with `Image.from_debian_base()` do not include CA certificates by default, which can cause TLS
> validation errors and block access to HTTPS-based storage such as Amazon S3. Libraries like Polars (e.g., `polars.scan_parquet()`) are particularly affected.
> **Solution:** Add `"ca-certificates"` using `.with_apt_packages()` in your image definition.
## Example: Defining an image based on uv script metadata
Another common technique for defining an image is to use [`uv` inline script metadata](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies) to specify your dependencies right in your Python file and then use the `flyte.Image.from_uv_script()` method to create a `flyte.Image` object.
The `from_uv_script` method starts with the default Flyte image and adds the dependencies specified in the `uv` metadata.
For example:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "numpy"
# ]
# main = "main"
# params = "x_list=[1,2,3,4,5,6,7,8,9,10]"
# ///

import flyte
import numpy as np

env = flyte.TaskEnvironment(
    name="my_env",
    image=flyte.Image.from_uv_script(
        __file__,
        name="my-image"
        # registry="registry.example.com/my-org"  # Only needed for local builds
    )
)


@env.task
def main(x_list: list[int]) -> float:
    arr = np.array(x_list)
    return float(np.mean(arr))


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, x_list=list(range(10)))
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/container-images/from_uv_script.py*
The advantage of this approach is that the dependencies used when running a script locally and when running it on the Flyte/Union backend are always the same (as long as you use `uv` to run your scripts locally).
This means you can develop and test your scripts in a consistent environment, reducing the chances of encountering issues when deploying to the backend.
The above example uses `flyte.init_from_config()` to enable remote runs.
To run the script locally instead, replace `flyte.init_from_config()` with `flyte.init()`.
> [!NOTE]
> When using `uv` metadata in this way, be sure to include the `flyte` package in your `uv` script dependencies.
> This will ensure that `flyte` is installed when running the script locally using `uv run`.
> When running on the Flyte/Union backend, the `flyte` package from the uv script dependencies will overwrite the one included automatically from the default Flyte image.
## Image building
There are two ways that the image can be built:
* If you are running a Flyte OSS instance then the image will be built locally on your machine and pushed to the container registry you specified in the `Image` definition.
* If you are running a Union instance, the image can be built locally, as with Flyte OSS, or using the Union `ImageBuilder`, which runs remotely on Union's infrastructure.
### Configuring the `builder`
[Earlier](../connecting-to-a-cluster), we discussed the `image.builder` property in the `config.yaml`.
For Flyte OSS instances, this property must be set to `local`.
For Union instances, this property can be set to `remote` to use the Union `ImageBuilder`, or `local` to build the image locally on your machine.
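As a sketch, the builder setting in `config.yaml` might look like the following. The exact surrounding structure of your file may differ; see **Connecting to a cluster** for the full file format.

```yaml
image:
  builder: local   # Flyte OSS: must be "local"; Union: "local" or "remote"
```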
### Local image building
When `image.builder` in the `config.yaml` is set to `local`, `flyte.run()` does the following:
* Builds the Docker image using your local Docker installation, installing the dependencies specified in the `uv` inline script metadata.
* Pushes the image to the container registry you specified.
* Deploys your code to the backend.
* Kicks off the execution of your workflow.
* Before the task that uses your custom image is executed, the backend pulls the image from the registry to set up the container.
> [!NOTE]
> The examples above use a placeholder registry value such as `registry="registry.example.com/my-org"`.
>
> Be sure to change this to the URL of your actual container registry (for example, `ghcr.io/my_gh_org` for the GitHub container registry).
You must ensure that:
* Docker is running on your local machine.
* You have successfully run `docker login` for that registry from your local machine (for example, GitHub uses the syntax `echo $GITHUB_TOKEN | docker login ghcr.io -u USERNAME --password-stdin`).
* Your Union/Flyte installation has read access to that registry.
> [!NOTE]
> If you are using the GitHub container registry (`ghcr.io`)
> note that images pushed there are private by default.
> You may need to go to the image URI, click **Package Settings**, and change the visibility to public in order to access the image.
>
> Other registries (such as Docker Hub) require that you pre-create the image repository before pushing the image.
> In that case you can set it to public when you create it.
>
> Public images are on the public internet and should only be used for testing purposes.
> Do not place proprietary code in public images.
### Remote `ImageBuilder`
`ImageBuilder` is a service provided by Union that builds container images on Union's infrastructure and provides an internal container registry for storing the built images.
When `image.builder` in the `config.yaml` is set to `remote` (and you are running Union.ai), `flyte.run()` does the following:
* Builds the Docker image on your Union instance with `ImageBuilder`.
* Pushes the image to a registry:
  * If you did not specify a `registry` in the `Image` definition, it pushes to the internal registry in your Union instance.
  * If you did specify a `registry`, it pushes to that registry. Be sure to also set the `registry_secret` parameter in the `Image` definition to enable `ImageBuilder` to authenticate to that registry (see **Configure tasks > Container images > Image building > Remote `ImageBuilder` > ImageBuilder with external registries**).
* Deploys your code to the backend.
* Kicks off the execution of your workflow.
* Before the task that uses your custom image is executed, the backend pulls the image from the registry to set up the container.
No Docker setup or other local configuration is required on your part.
> [!NOTE]
> The Flyte SDK checks whether the image builder is enabled for your cluster by verifying that the `image_build` task is deployed in the `system` project within the `production` domain.
> If you are using custom roles and policies, ensure that users are granted the `view_flyte_inventory` action for the `production/system` project-domain pair.
> See the [V1 user management documentation](/docs/v1/byoc//user-guide/administration/user-management) for more details on creating and assigning custom roles and policies (V2 user management currently works identically to V1).
#### ImageBuilder with external registries
If you want to push the images built by `ImageBuilder` to an external registry, you can do this by setting the `registry` parameter in the `Image` object.
You will also need to set the `registry_secret` parameter to provide the secret needed to push and pull images to the private registry.
For example:
```python
# Add registry credentials so the Union remote builder can pull the base image
# and push the resulting image to your private registry.
image = flyte.Image.from_debian_base(
    name="my-image",
    base_image="registry.example.com/my-org/my-private-image:latest",
    registry="registry.example.com/my-org",
    registry_secret="my-secret"
)

# Reference the same secret in the TaskEnvironment so Flyte can pull the image at runtime.
env = flyte.TaskEnvironment(
    name="my_task_env",
    image=image,
    secrets="my-secret"
)
```
The value of the `registry_secret` parameter must be the name of a Flyte secret of type `image_pull` that contains the credentials needed to access the private registry. It must match the name specified in the `secrets` parameter of the `TaskEnvironment` so that Flyte can use it to pull the image at runtime.
To create an `image_pull` secret for the remote builder and the task environment, run the following command:
```bash
flyte create secret --type image_pull my-secret --from-file ~/.docker/config.json
```
The format of this secret matches the standard Kubernetes [image pull secret](https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/#log-in-to-docker-hub), and should look like this:
```json
{
"auths": {
"registry.example.com": {
"auth": "base64-encoded-auth"
}
}
}
```
> [!NOTE]
> The `auth` field contains the base64-encoded credentials for your registry (username and password or token).
### Install private PyPI packages
To install Python packages from a private PyPI index (for example, from GitHub), you can mount a secret to the image layer.
This allows your build to authenticate securely during dependency installation.
For example:
```python
import flyte

# Authenticate with a GitHub personal access token injected at build time
# from the GITHUB_PAT secret.
private_package = "git+https://$GITHUB_PAT@github.com/pingsutw/flytex.git@2e20a2acebfc3877d84af643fdd768edea41d533"

image = (
    flyte.Image.from_debian_base()
    .with_apt_packages("git")
    .with_pip_packages(private_package, pre=True, secret_mounts=flyte.Secret("GITHUB_PAT"))
)
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/resources ===
# Resources
Task resources specify the computational limits and requests (CPU, memory, GPU, storage) that will be allocated to each task's container during execution.
To specify resource requirements for your task, instantiate a `Resources` object with the desired parameters and assign it to either
the `resources` parameter of the `TaskEnvironment` or the `resources` parameter of the `override` function (for invocation overrides).
Every task defined using that `TaskEnvironment` will run with the specified resources.
If a specific task has its own `resources` defined in the decorator, it will override the environment's resources for that task only.
If neither `TaskEnvironment` nor the task decorator specifies `resources`, the default resource allocation will be used.
## Resources data class
For the full class definition, parameter types, and accepted formats, see the [`Resources` API reference](../../api-reference/flyte-sdk/packages/flyte/resources).
The main parameters are:
- **`cpu`**: CPU allocation as a number, a string (`"500m"`), or a `(request, limit)` tuple.
- **`memory`**: Memory in Kubernetes units, such as `"4Gi"`, or a `(request, limit)` tuple.
- **`gpu`**: GPU allocation as a string like `"A100:2"`, an integer count, or a `GPU()`/`TPU()`/`Device()` object for advanced configuration.
- **`disk`**: Ephemeral storage, such as `"10Gi"`.
- **`shm`**: Shared memory, such as `"1Gi"` or `"auto"`.
## Examples
### Usage in TaskEnvironment
Here's a complete example of defining a TaskEnvironment with resource specifications for a machine learning training workload:
```python
import flyte

# Define a TaskEnvironment for ML training tasks
env = flyte.TaskEnvironment(
    name="ml-training",
    resources=flyte.Resources(
        cpu=("2", "4"),          # Request 2 cores, allow up to 4 cores for scaling
        memory=("2Gi", "12Gi"),  # Request 2 GiB, allow up to 12 GiB for large datasets
        disk="50Gi",             # 50 GiB ephemeral storage for checkpoints
        shm="8Gi"                # 8 GiB shared memory for efficient data loading
    )
)

# Use the environment for tasks
@env.task
async def train_model(dataset_path: str) -> str:
    # This task will run with flexible resource allocation
    return "model trained"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/resources/resources.py*
### Usage in a task-specific override
```python
# Demonstrate resource override at task invocation level
@env.task
async def heavy_training_task() -> str:
    return "heavy model trained with overridden resources"

@env.task
async def main():
    # Task using environment-level resources
    result = await train_model("data.csv")
    print(result)

    # Task with overridden resources at invocation time
    result = await heavy_training_task.override(
        resources=flyte.Resources(
            cpu="4",
            memory="24Gi",
            disk="100Gi",
            shm="16Gi"
        )
    )()
    print(result)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/resources/resources.py*
## Resource types
### CPU resources
CPU can be specified in several formats:
```python
# String formats (Kubernetes-style)
flyte.Resources(cpu="500m") # 500 milliCPU (0.5 cores)
flyte.Resources(cpu="2") # 2 CPU cores
flyte.Resources(cpu="1.5") # 1.5 CPU cores
# Numeric formats
flyte.Resources(cpu=1) # 1 CPU core
flyte.Resources(cpu=0.5) # 0.5 CPU cores
# Request and limit ranges
flyte.Resources(cpu=("1", "2")) # Request 1 core, limit to 2 cores
flyte.Resources(cpu=(1, 4)) # Request 1 core, limit to 4 cores
```
### Memory resources
Memory specifications follow Kubernetes conventions:
```python
# Standard memory units
flyte.Resources(memory="512Mi") # 512 MiB
flyte.Resources(memory="1Gi") # 1 GiB
flyte.Resources(memory="2Gi") # 2 GiB
flyte.Resources(memory="500M") # 500 MB (decimal)
flyte.Resources(memory="1G") # 1 GB (decimal)
# Request and limit ranges
flyte.Resources(memory=("1Gi", "4Gi")) # Request 1 GiB, limit to 4 GiB
```
### GPU resources
Flyte supports various GPU types and configurations:
#### Simple GPU allocation
```python
# Basic GPU count
flyte.Resources(gpu=1) # 1 GPU (any available type)
flyte.Resources(gpu=4) # 4 GPUs
# Specific GPU types with quantity
flyte.Resources(gpu="T4:1") # 1 NVIDIA T4 GPU
flyte.Resources(gpu="A100:2") # 2 NVIDIA A100 GPUs
flyte.Resources(gpu="H100:8") # 8 NVIDIA H100 GPUs
```
#### Advanced GPU configuration
You can also use the `GPU` helper class for more detailed configurations:
```python
# Using the GPU helper class
gpu_config = flyte.GPU(device="A100", quantity=2)
flyte.Resources(gpu=gpu_config)

# GPU with memory partitioning (A100 only)
partitioned_gpu = flyte.GPU(
    device="A100",
    quantity=1,
    partition="1g.5gb"  # 1/7th of an A100 with 5 GB memory
)
flyte.Resources(gpu=partitioned_gpu)

# A100 80GB with partitioning
large_partition = flyte.GPU(
    device="A100 80G",
    quantity=1,
    partition="7g.80gb"  # Full A100 80GB
)
flyte.Resources(gpu=large_partition)
```
#### Supported GPU types
- **T4**: Entry-level training and inference
- **L4**: Optimized for AI inference
- **L40s**: High-performance compute
- **A100**: High-end training and inference (40GB)
- **A100 80G**: High-end training with more memory (80GB)
- **H100**: Latest generation, highest performance
### Custom device specifications
You can also define custom devices if your infrastructure supports them:
```python
# Custom device configuration
custom_device = flyte.Device(
    device="custom_accelerator",
    quantity=2,
    partition="large"
)
resources = flyte.Resources(gpu=custom_device)
```
### TPU resources
For Google Cloud TPU workloads you can specify TPU resources using the `TPU` helper class:
```python
# TPU v5p configuration
tpu_config = flyte.TPU(device="V5P", partition="2x2x1")
flyte.Resources(gpu=tpu_config) # Note: TPUs use the gpu parameter
# TPU v6e configuration
tpu_v6e = flyte.TPU(device="V6E", partition="4x4")
flyte.Resources(gpu=tpu_v6e)
```
### Storage resources
Flyte provides two types of storage resources for tasks: ephemeral disk storage and shared memory.
These resources are essential for tasks that need temporary storage for processing data, caching intermediate results, or sharing data between processes.
#### Disk storage
Ephemeral disk storage provides temporary space for your tasks to store intermediate files, downloaded datasets, model checkpoints, and other temporary data. This storage is automatically cleaned up when the task completes.
```python
flyte.Resources(disk="10Gi") # 10 GiB ephemeral storage
flyte.Resources(disk="100Gi") # 100 GiB ephemeral storage
flyte.Resources(disk="1Ti") # 1 TiB for large-scale data processing
# Common use cases
flyte.Resources(disk="50Gi") # ML model training with checkpoints
flyte.Resources(disk="200Gi") # Large dataset preprocessing
flyte.Resources(disk="500Gi") # Video/image processing workflows
```
#### Shared memory
Shared memory (`/dev/shm`) is a high-performance, RAM-based storage area that can be shared between processes within the same container. It's particularly useful for machine learning workflows that need fast data loading and inter-process communication.
```python
flyte.Resources(shm="1Gi") # 1 GiB shared memory (/dev/shm)
flyte.Resources(shm="auto") # Auto-sized shared memory
flyte.Resources(shm="16Gi") # Large shared memory for distributed training
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/secrets ===
# Secrets
Flyte secrets enable you to securely store and manage sensitive information, such as API keys, passwords, and other credentials.
Secrets reside in a secret store on the data plane of your Union/Flyte backend.
You can create, list, and delete secrets in the store using the Flyte CLI or SDK.
Secrets in the store can be accessed and used within your workflow tasks, without exposing any cleartext values in your code.
## Creating a literal string secret
You can create a secret using the [`flyte create secret`](../../api-reference/flyte-cli#flyte-create-secret) command like this:
```bash
flyte create secret MY_SECRET_KEY my_secret_value
```
This will create a secret called `MY_SECRET_KEY` with the value `my_secret_value`.
This secret will be scoped to your entire organization.
It will be available across all projects and domains in your organization.
See the **Configure tasks > Secrets > Scoping secrets** section below for more details.
See **Configure tasks > Secrets > Using a literal string secret** for how to access the secret in your task code.
## Creating a file secret
You can also create a secret by specifying a local file:
```bash
flyte create secret MY_SECRET_KEY --from-file /local/path/to/my_secret_file
```
In this case, when accessing the secret in your task code, you will need to **Configure tasks > Secrets > Using a file secret**.
## Scoping secrets
When you create a secret without specifying a project or domain, as we did above, the secret is scoped to the organization level.
This means that the secret will be available across all projects and domains in the organization.
You can optionally specify either or both of the `--project` and `--domain` flags to restrict the scope of the secret to:
* A specific project (across all domains)
* A specific domain (across all projects)
* A specific project and a specific domain.
For example, to create a secret that is only available in `my_project/development`, you would execute the following command:
```bash
flyte create secret --project my_project --domain development MY_SECRET_KEY my_secret_value
```
## Listing secrets
You can list existing secrets with the [`flyte get secret`](../../api-reference/flyte-cli#flyte-get-secret) command.
For example, the following command will list all secrets in the organization:
```bash
flyte get secret
```
Specifying either or both of the `--project` and `--domain` flags will list the secrets that are **only** available in that project and/or domain.
For example, to list the secrets that are only available in `my_project` and domain `development`, you would run:
```bash
flyte get secret --project my_project --domain development
```
## Deleting secrets
To delete a secret, use the [`flyte delete secret`](../../api-reference/flyte-cli#flyte-delete-secret) command:
```bash
flyte delete secret MY_SECRET_KEY
```
## Using a literal string secret
To use a literal string secret, specify it in the `TaskEnvironment` along with the name of the environment variable into which it will be injected.
You can then access it using `os.getenv()` in your task code.
For example:
```python
import os

import flyte

env_1 = flyte.TaskEnvironment(
    name="env_1",
    secrets=[
        flyte.Secret(key="my_secret", as_env_var="MY_SECRET_ENV_VAR"),
    ]
)

@env_1.task
def task_1():
    my_secret_value = os.getenv("MY_SECRET_ENV_VAR")
    print(f"My secret value is: {my_secret_value}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/secrets/secrets.py*
## Using a file secret
To use a file secret, specify it in the `TaskEnvironment` along with the `mount="/etc/flyte/secrets"` argument (with that precise value).
The file will be mounted at `/etc/flyte/secrets/`.
For example:
```python
env_2 = flyte.TaskEnvironment(
    name="env_2",
    secrets=[
        flyte.Secret(key="my_secret", mount="/etc/flyte/secrets"),
    ]
)

@env_2.task
def task_2():
    with open("/etc/flyte/secrets/my_secret", "r") as f:
        my_secret_file_content = f.read()
    print(f"My secret file content is: {my_secret_file_content}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/secrets/secrets.py*
> [!NOTE]
> Currently, to access a file secret you must specify a `mount` parameter value of `"/etc/flyte/secrets"`.
> This fixed path is the directory in which the secret file will be placed.
> The name of the secret file will be equal to the key of the secret.
> [!NOTE]
> A `TaskEnvironment` can only access a secret if the scope of the secret includes the project and domain where the `TaskEnvironment` is deployed.
> [!WARNING]
> Do not return secret values from tasks, as this will expose secrets to the control plane.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/caching ===
# Caching
Flyte 2 provides intelligent **task output caching** that automatically avoids redundant computation by reusing previously computed task results.
> [!NOTE]
> Caching works at the task level and caches complete task outputs.
> For function-level checkpointing and resumption *within tasks*, see [Traces](../task-programming/traces).
## Overview
By default, caching is disabled.
If caching is enabled for a task, then Flyte determines a **cache key** for the task.
The key is composed of the following:
* Final inputs: The set of inputs after removing any specified in the `ignored_inputs`.
* Task name: The fully-qualified name of the task.
* Interface hash: A hash of the task's input and output types.
* Cache version: The cache version string.
If the cache behavior is set to `"auto"`, the cache version is automatically generated using a hash of the task's source code (or according to the custom policy if one is specified).
If the cache behavior is set to `"override"`, the cache version can be specified explicitly using the `version_override` parameter.
When the task runs, Flyte checks if a cache entry exists for the key.
If found, the cached result is returned immediately instead of re-executing the task.
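To make the key composition concrete, here is a conceptual model in plain Python. This is **not** Flyte's actual implementation; the function and its hashing scheme are hypothetical, purely to illustrate how the components listed above interact:

```python
import hashlib
import json

# Illustrative only: a conceptual model of how the cache-key components
# described above might combine. NOT Flyte's real implementation.
def conceptual_cache_key(task_name: str, inputs: dict, ignored_inputs: tuple,
                         interface_hash: str, cache_version: str) -> str:
    # Final inputs: drop any inputs listed in ignored_inputs
    final_inputs = {k: v for k, v in inputs.items() if k not in ignored_inputs}
    payload = json.dumps(
        [task_name, final_inputs, interface_hash, cache_version],
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()

# Changing an ignored input does not change the key...
k1 = conceptual_cache_key("proj.transform", {"data": "a", "debug": True}, ("debug",), "if1", "v1")
k2 = conceptual_cache_key("proj.transform", {"data": "a", "debug": False}, ("debug",), "if1", "v1")
# ...but changing the cache version does.
k3 = conceptual_cache_key("proj.transform", {"data": "a", "debug": True}, ("debug",), "if1", "v2")
```

The model shows why ignored inputs, code changes (via the version), and interface changes each affect cache hits independently.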
## Basic caching usage
Flyte 2 supports three main cache behaviors:
### `"auto"` - Automatic versioning
```python
@env.task(cache=flyte.Cache(behavior="auto"))
async def auto_versioned_task(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
With `behavior="auto"`, the cache version is automatically generated based on the function's source code.
If you change the function implementation, the cache is automatically invalidated.
- **When to use**: Development and most production scenarios.
- **Cache invalidation**: Automatic when function code changes.
- **Benefits**: Zero-maintenance caching that "just works".
You can also use the direct string shorthand:
```python
@env.task(cache="auto")
async def auto_versioned_task_2(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
### `"override"`
With `behavior="override"`, you can specify a custom cache key in the `version_override` parameter.
Since the cache key is fixed as part of the code, it can be manually changed when you need to invalidate the cache.
```python
@env.task(cache=flyte.Cache(behavior="override", version_override="v1.2"))
async def manually_versioned_task(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
- **When to use**: When you need explicit control over cache invalidation.
- **Cache invalidation**: Manual, by changing `version_override`.
- **Benefits**: Stable caching across code changes that don't affect logic.
### `"disable"` - No caching
To explicitly disable caching, use the `"disable"` behavior.
**This is the default behavior.**
```python
@env.task(cache=flyte.Cache(behavior="disable"))
async def always_fresh_task(data: str) -> str:
    return get_current_timestamp() + await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
- **When to use**: Non-deterministic functions, side effects, or always-fresh data.
- **Cache invalidation**: N/A - never cached.
- **Benefits**: Ensures execution every time.
You can also use the direct string shorthand:
```python
@env.task(cache="disable")
async def always_fresh_task_2(data: str) -> str:
    return get_current_timestamp() + await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
## Advanced caching configuration
### Ignoring specific inputs
Sometimes you want to cache based on some inputs but not others:
```python
@env.task(cache=flyte.Cache(behavior="auto", ignored_inputs=("debug_flag",)))
async def selective_caching(data: str, debug_flag: bool) -> str:
    if debug_flag:
        print(f"Debug: transforming {data}")
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
**This is useful for**:
- Debug flags that don't affect computation
- Logging levels or output formats
- Metadata that doesn't impact results
### Cache serialization
Cache serialization ensures that only one instance of a task runs at a time for identical inputs:
```python
@env.task(cache=flyte.Cache(behavior="auto", serialize=True))
async def expensive_model_training(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
**When to use serialization**:
- Very expensive computations (model training, large data processing)
- Shared resources that shouldn't be accessed concurrently
- Operations where multiple parallel executions provide no benefit
**How it works**:
1. First execution acquires a reservation and runs normally.
2. Concurrent executions with identical inputs wait for the first to complete.
3. Once complete, all waiting executions receive the cached result.
4. If the running execution fails, another waiting execution takes over.
### Salt for cache key variation
Use `salt` to vary cache keys without changing function logic:
```python
@env.task(cache=flyte.Cache(behavior="auto", salt="experiment_2024_q4"))
async def experimental_analysis(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
**`salt` is useful for**:
- A/B testing with identical code.
- Temporary cache namespaces for experiments.
- Environment-specific cache isolation.
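One pattern for environment-specific cache isolation (an assumption, not an official Flyte convention) is to derive the salt from a deployment environment variable:

```python
import os

# DEPLOY_ENV is a hypothetical variable name; any stable per-environment
# identifier would work. Falls back to "dev" when unset.
cache_salt = f"experiment_2024_q4_{os.environ.get('DEPLOY_ENV', 'dev')}"

# The computed salt could then be passed to the decorator, e.g.:
# @env.task(cache=flyte.Cache(behavior="auto", salt=cache_salt))
```

With this pattern, identical code deployed to different environments keeps separate cache namespaces without any code changes.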
## Cache policies
For details on implementing custom cache policies, see the [`CachePolicy` protocol](../../api-reference/flyte-sdk/packages/flyte/cachepolicy) and [`Cache` class](../../api-reference/flyte-sdk/packages/flyte/cache) API references.
For `behavior="auto"`, Flyte uses cache policies to generate version hashes.
### Function body policy (default)
The default `FunctionBodyPolicy` generates cache versions from the function's source code:
```python
from flyte._cache import FunctionBodyPolicy

@env.task(cache=flyte.Cache(
    behavior="auto",
    policies=[FunctionBodyPolicy()]  # This is the default and does not need to be specified explicitly
))
async def code_sensitive_task(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
### Custom cache policies
You can implement custom cache policies by following the `CachePolicy` protocol:
```python
from flyte._cache import CachePolicy

class DatasetVersionPolicy(CachePolicy):
    def get_version(self, salt: str, params) -> str:
        # Generate a version based on custom logic
        dataset_version = get_dataset_version()
        return f"{salt}_{dataset_version}"

@env.task(cache=flyte.Cache(behavior="auto", policies=[DatasetVersionPolicy()]))
async def dataset_dependent_task(data: str) -> str:
    # Cache is invalidated when the dataset version changes
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
## Caching configuration at different levels
You can configure caching at three levels: `TaskEnvironment` definition, `@env.task` decorator, and task invocation.
### `TaskEnvironment` Level
You can configure caching at the `TaskEnvironment` level.
This will set the default cache behavior for all tasks defined using that environment.
For example:
```python
cached_env = flyte.TaskEnvironment(
    name="cached_environment",
    cache=flyte.Cache(behavior="auto")  # Default for all tasks
)

@cached_env.task  # Inherits auto caching from environment
async def inherits_caching(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
### `@env.task` decorator level
By setting the cache parameter in the `@env.task` decorator, you can override the environment's default cache behavior for specific tasks:
```python
@cached_env.task(cache=flyte.Cache(behavior="disable"))  # Override environment default
async def decorator_caching(data: str) -> str:
    return await transform_data(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
### `task.override` level
By setting the cache parameter in the `task.override` method, you can override the cache behavior for specific task invocations:
```python
@env.task
async def override_caching_on_call(data: str) -> str:
    # Create an overridden version and call it
    overridden_task = inherits_caching.override(cache=flyte.Cache(behavior="disable"))
    return await overridden_task(data)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py*
## Runtime cache control
You can also force cache invalidation for a specific run:
```python
# Disable caching for this specific execution
run = flyte.with_runcontext(overwrite_cache=True).run(my_cached_task, data="test")
```
## Project and domain cache isolation
Caches are automatically isolated by:
- **Project**: Tasks in different projects have separate cache namespaces.
- **Domain**: Development, staging, and production domains maintain separate caches.
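One way to picture this isolation is a cache key prefixed by project and domain (a simplification for illustration, not Flyte's actual key format):

```python
def cache_key(project: str, domain: str, task: str, version: str) -> str:
    # Keys from different projects or domains can never collide.
    return f"{project}/{domain}/{task}/{version}"

dev_key = cache_key("recsys", "development", "transform", "v1")
prod_key = cache_key("recsys", "production", "transform", "v1")
print(dev_key != prod_key)  # same task and version, separate caches
```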
## Local development caching
When running locally, Flyte maintains a local cache:
```python
# Local execution uses ~/.flyte/local-cache/
flyte.init() # Local mode
result = flyte.run(my_cached_task, data="test")
```
Local cache behavior:
- Stored in `~/.flyte/local-cache/` directory
- No project/domain isolation (since running locally)
- Disabled by setting `FLYTE_LOCAL_CACHE_ENABLED=false`
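For example, to disable the local cache from within a script (assuming `FLYTE_LOCAL_CACHE_ENABLED` is read when Flyte initializes, so it must be set beforehand):

```python
import os

# Turn off the local result cache before initializing Flyte in local mode.
os.environ["FLYTE_LOCAL_CACHE_ENABLED"] = "false"

# flyte.init()  # subsequent local runs recompute instead of reading ~/.flyte/local-cache/
```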
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/reusable-containers ===
# Reusable containers
By default, each task execution in Flyte and Union runs in a fresh container instance that is created just for that execution and then discarded.
With reusable containers, the same container can be reused across multiple executions and tasks.
This approach reduces startup overhead and improves resource efficiency.
> [!NOTE]
> The reusable container feature is only available when running your Flyte code on a Union backend.
> See [one of the Union.ai product variants of this page](/docs/v2/byoc//user-guide/reusable-containers) for details.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/pod-templates ===
# Pod templates
Flyte is built on Kubernetes and leverages its powerful container orchestration capabilities. A Kubernetes [pod](https://kubernetes.io/docs/concepts/workloads/pods/) is a group of one or more containers that share storage and network resources. While Flyte automatically runs your task code in a container, pod templates let you customize the entire pod specification for advanced use cases.
The `pod_template` parameter in `TaskEnvironment` allows you to:
- **Add sidecar containers**: Run metrics exporters, service proxies, or specialized services alongside your task
- **Mount volumes**: Attach persistent storage or cloud storage like GCS or S3
- **Configure metadata**: Set custom labels and annotations for monitoring, routing, or cluster policies
- **Manage resources**: Configure resource requests, limits, and affinities
- **Inject configuration**: Add secrets, environment variables, or config maps
- **Access private registries**: Specify image pull secrets
## How it works
When you define a pod template:
1. **Primary container**: Flyte automatically injects your task code into the container specified by `primary_container_name` (default: `"primary"`)
2. **Automatic monitoring**: Flyte watches the primary container and exits the entire pod when it completes
3. **Image handling**: The image for your task environment is built automatically by Flyte; images for sidecar containers must be provided by you
4. **Local execution**: When running locally, only the task code executes; additional containers are not started
## Requirements
To use pod templates, install the Kubernetes Python client:
```bash
pip install kubernetes
```
Or add it to your image dependencies:
```python
image = flyte.Image.from_debian_base().with_pip_packages("kubernetes")
```
## Basic usage
Here's a complete example showing how to configure labels, annotations, environment variables, and image pull secrets:
```python
# /// script
# requires-python = "==3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "kubernetes"
# ]
# ///

import flyte
from kubernetes.client import (
    V1Container,
    V1EnvVar,
    V1LocalObjectReference,
    V1PodSpec,
)

# Create a custom pod template
pod_template = flyte.PodTemplate(
    primary_container_name="primary",  # Name of the main container
    labels={"lKeyA": "lValA"},  # Custom pod labels
    annotations={"aKeyA": "aValA"},  # Custom pod annotations
    pod_spec=V1PodSpec(  # Kubernetes pod specification
        containers=[
            V1Container(
                name="primary",
                env=[V1EnvVar(name="hello", value="world")]  # Environment variables
            )
        ],
        image_pull_secrets=[  # Access to private registries
            V1LocalObjectReference(name="regcred-test")
        ],
    ),
)

# Use the pod template in a TaskEnvironment
env = flyte.TaskEnvironment(
    name="hello_world",
    pod_template=pod_template,  # Apply the custom pod template
    image=flyte.Image.from_uv_script(__file__, name="flyte", pre=True),
)

@env.task
async def say_hello(data: str) -> str:
    return f"Hello {data}"

@env.task
async def say_hello_nested(data: str = "default string") -> str:
    return await say_hello(data=data)

if __name__ == "__main__":
    flyte.init_from_config()
    result = flyte.run(say_hello_nested, data="hello world")
    print(result.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/pod-templates/pod_template.py*
## PodTemplate parameters
The `PodTemplate` class provides the following parameters:
| Parameter | Type | Description |
|-----------|------|-------------|
| `primary_container_name` | `str` | Name of the container where task code runs (default: `"primary"`). Must match a container in `pod_spec`. |
| `pod_spec` | `V1PodSpec` | Kubernetes pod specification for configuring containers, volumes, security contexts, and more. |
| `labels` | `dict[str, str]` | Pod labels for organization and selection by Kubernetes selectors. |
| `annotations` | `dict[str, str]` | Pod annotations for metadata and integrations (doesn't affect scheduling). |
## Volume mounts
Pod templates are commonly used to mount volumes for persistent storage or cloud storage access:
```python
from kubernetes.client import (
    V1Container,
    V1PodSpec,
    V1Volume,
    V1VolumeMount,
    V1CSIVolumeSource,
)

import flyte

pod_template = flyte.PodTemplate(
    primary_container_name="primary",
    pod_spec=V1PodSpec(
        containers=[
            V1Container(
                name="primary",
                volume_mounts=[
                    V1VolumeMount(
                        name="data-volume",
                        mount_path="/mnt/data",
                        read_only=False,
                    )
                ],
            )
        ],
        volumes=[
            V1Volume(
                name="data-volume",
                csi=V1CSIVolumeSource(
                    driver="your-csi-driver",
                    volume_attributes={"key": "value"},
                ),
            )
        ],
    ),
)

env = flyte.TaskEnvironment(
    name="volume-example",
    pod_template=pod_template,
    image=flyte.Image.from_debian_base(),
)

@env.task
async def process_data() -> str:
    # Access mounted volume
    with open("/mnt/data/input.txt", "r") as f:
        data = f.read()
    return f"Processed {len(data)} bytes"
```
### GCS/S3 volume mounts
Mount cloud storage directly into your pod for efficient data access:
```python
from kubernetes.client import V1Container, V1PodSpec, V1Volume, V1VolumeMount, V1CSIVolumeSource

import flyte

# GCS example with CSI driver
pod_template = flyte.PodTemplate(
    primary_container_name="primary",
    annotations={
        "gke-gcsfuse/volumes": "true",
        "gke-gcsfuse/cpu-limit": "2",
        "gke-gcsfuse/memory-limit": "1Gi",
    },
    pod_spec=V1PodSpec(
        containers=[
            V1Container(
                name="primary",
                volume_mounts=[V1VolumeMount(name="gcs", mount_path="/mnt/gcs")],
            )
        ],
        volumes=[
            V1Volume(
                name="gcs",
                csi=V1CSIVolumeSource(
                    driver="gcsfuse.csi.storage.gke.io",
                    volume_attributes={"bucketName": "my-bucket"},
                ),
            )
        ],
    ),
)
```
## Sidecar containers
Add sidecar containers to run alongside your task. Common use cases include:
- **Metrics exporters**: Prometheus, Datadog agents
- **Service proxies**: Istio, Linkerd sidecars
- **Data services**: Databases, caches, or specialized services like Nvidia NIMs
```python
from kubernetes.client import V1Container, V1PodSpec

import flyte

pod_template = flyte.PodTemplate(
    primary_container_name="primary",
    pod_spec=V1PodSpec(
        containers=[
            # Primary container (where your task code runs)
            V1Container(name="primary"),
            # Sidecar container
            V1Container(
                name="metrics-sidecar",
                image="prom/pushgateway:latest",
                ports=[{"containerPort": 9091}],
            ),
        ],
    ),
)

env = flyte.TaskEnvironment(
    name="sidecar-example",
    pod_template=pod_template,
    image=flyte.Image.from_debian_base().with_pip_packages("requests"),
)

@env.task
async def task_with_metrics() -> str:
    import requests

    # Send metrics to sidecar
    requests.post("http://localhost:9091/metrics", data="my_metric 42")
    # Your task logic
    return "Task completed with metrics"
```
## Image pull secrets
Configure private registry access:
```python
from kubernetes.client import V1Container, V1PodSpec, V1LocalObjectReference

import flyte

pod_template = flyte.PodTemplate(
    primary_container_name="primary",
    pod_spec=V1PodSpec(
        containers=[V1Container(name="primary")],
        image_pull_secrets=[V1LocalObjectReference(name="my-registry-secret")],
    ),
)
```
## Cluster-specific configuration
Pod templates are often used to configure Kubernetes-specific settings required by your cluster, even when not using multiple containers:
```python
import flyte

pod_template = flyte.PodTemplate(
    primary_container_name="primary",
    annotations={
        "iam.amazonaws.com/role": "my-task-role",  # AWS IAM role
        "cluster-autoscaler.kubernetes.io/safe-to-evict": "false",
    },
    labels={
        "cost-center": "ml-team",
        "project": "recommendations",
    },
)
```
## Important notes
1. **Local execution**: Pod templates only apply to remote execution. When running locally, only your task code executes.
2. **Image building**: Flyte automatically builds and manages the image for your task environment. Images for sidecar containers must be pre-built and available in a registry.
3. **Primary container**: Your task code is automatically injected into the container matching `primary_container_name`. This container must be defined in the `pod_spec.containers` list.
4. **Lifecycle management**: Flyte monitors the primary container and terminates the entire pod when it exits, ensuring sidecar containers don't run indefinitely.
## Best practices
1. **Start simple**: Begin with basic labels and annotations before adding complex sidecars
2. **Test locally first**: Verify your task logic works locally before adding pod customizations
3. **Use environment-specific templates**: Different environments (dev, staging, prod) may need different pod configurations
4. **Set resource limits**: Always set resource requests and limits for sidecars to prevent cluster issues
5. **Security**: Use image pull secrets and least-privilege service accounts
## Learn more
- [Kubernetes Pods Documentation](https://kubernetes.io/docs/concepts/workloads/pods/)
- [Kubernetes Python Client](https://github.com/kubernetes-client/python)
- [V1PodSpec Reference](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#podspec-v1-core)
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/multiple-environments ===
# Multiple environments
In many applications, different tasks within your workflow may require different configurations.
Flyte enables you to manage this complexity by allowing multiple environments within a single workflow.
Multiple environments are useful when:
- Different tasks in your workflow need different dependencies.
- Some tasks require specific CPU/GPU or memory configurations.
- A task requires a secret that other tasks do not (and you want to limit exposure of the secret value).
- You're integrating specialized tools that have conflicting requirements.
## Constraints on multiple environments
To use multiple environments in your workflow you define multiple `TaskEnvironment` instances, each with its own configuration, and then assign tasks to their respective environments.
There are, however, two additional constraints that you must take into account.
If `task_1` in environment `env_1` calls `task_2` in environment `env_2`, then:
1. `env_1` must declare a deployment-time dependency on `env_2` via the `depends_on` parameter of the `TaskEnvironment` that defines `env_1`.
2. The image used by the `TaskEnvironment` of `env_1` must include all dependencies of the module containing `task_2` (unless `task_2` is invoked as a remote task).
### Task `depends_on` constraints
The `depends_on` parameter in `TaskEnvironment` is used to provide deployment-time dependencies by establishing a relationship between one `TaskEnvironment` and another.
The system uses this information to determine which environments (and, specifically, which images) need to be built in order to run the code.
On `flyte run` (or `flyte deploy`), the system walks the tree defined by the `depends_on` relationships, starting with the environment of the task being invoked (or the environment being deployed, in the case of `flyte deploy`), and prepares each required environment.
Most importantly, it ensures that the container images needed for all required environments are available (and builds them if they are not).
This deploy-time determination of what to build is important because it means that for any given `run` or `deploy`, only those environments that are actually required are built.
The alternative strategy of building all environments defined in the set of deployed code can lead to unnecessary and expensive builds, especially when iterating on code.
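The walk can be sketched in a few lines (a hypothetical data model, not the SDK's internals): starting from the invoked environment, only transitively reachable environments are collected, so unrelated environments are never built:

```python
def environments_to_build(env, seen=None):
    # Depth-first walk of depends_on, collecting each environment exactly once.
    if seen is None:
        seen = set()
    if env["name"] not in seen:
        seen.add(env["name"])
        for dep in env.get("depends_on", []):
            environments_to_build(dep, seen)
    return seen

msa_env = {"name": "msa_env"}
fold_env = {"name": "fold_env"}
main_env = {"name": "multi_env", "depends_on": [fold_env, msa_env]}
unrelated_env = {"name": "other_env"}  # defined but never reached, so never built

print(sorted(environments_to_build(main_env)))  # ['fold_env', 'msa_env', 'multi_env']
```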
### Dependency inclusion constraints
When a parent task invokes a child task in a different environment, the container image of the parent task environment must include all dependencies used by the child task.
This is necessary because of the way task invocation works in Flyte:
- When a child task is invoked by function name, that function necessarily has to be imported into the parent task's Python environment.
- This results in all the dependencies of the child task function also being imported.
- Nonetheless, the actual execution of the child task still occurs in its own environment.
To avoid this requirement, you can invoke a task in another environment _remotely_.
## Example
The following example is a (very) simple mock of an AlphaFold2 pipeline.
It demonstrates a workflow with three tasks, each in its own environment.
The example project looks like this:
```bash
├── msa/
│   ├── __init__.py
│   └── run.py
├── fold/
│   ├── __init__.py
│   └── run.py
├── __init__.py
└── main.py
```
(The source code for this example can be found here: [AlphaFold2 mock example](https://github.com/unionai/unionai-examples/tree/main/v2/user-guide/task-configuration/multiple-environments/af2))
In file `msa/run.py` we define the task `run_msa`, which mocks the multiple sequence alignment step of the process:
```python
import flyte
from flyte.io import File

MSA_PACKAGES = ["pytest"]

msa_image = flyte.Image.from_debian_base().with_pip_packages(*MSA_PACKAGES)

msa_env = flyte.TaskEnvironment(name="msa_env", image=msa_image)

@msa_env.task
def run_msa(x: str) -> File:
    f = File.new_remote()
    with f.open_sync("w") as fp:
        fp.write(x)
    return f
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/msa/run.py*
* A dedicated image (`msa_image`) is built using the `MSA_PACKAGES` dependency list, on top of the standard base image.
* A dedicated environment (`msa_env`) is defined for the task, using `msa_image`.
* The task is defined within the context of the `msa_env` environment.
In file `fold/run.py` we define the task `run_fold`, which mocks the fold step of the process:
```python
import flyte
from flyte.io import File

FOLD_PACKAGES = ["ruff"]

fold_image = flyte.Image.from_debian_base().with_pip_packages(*FOLD_PACKAGES)

fold_env = flyte.TaskEnvironment(name="fold_env", image=fold_image)

@fold_env.task
def run_fold(sequence: str, msa: File) -> list[str]:
    with msa.open_sync("r") as f:
        msa_content = f.read()
    return [msa_content, sequence]
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/fold/run.py*
* A dedicated image (`fold_image`) is built using the `FOLD_PACKAGES` dependency list, on top of the standard base image.
* A dedicated environment (`fold_env`) is defined for the task, using `fold_image`.
* The task is defined within the context of the `fold_env` environment.
Finally, in file `main.py` we define the task `main` that ties everything together into a workflow.
We import the required modules and functions:
```python
import logging
import pathlib

from fold.run import fold_env, fold_image, run_fold
from msa.run import msa_env, MSA_PACKAGES, run_msa

import flyte
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py*
Notice that we import
* The task functions that we will be calling: `run_fold` and `run_msa`.
* The environments of those tasks: `fold_env` and `msa_env`.
* The dependency list of the `run_msa` task: `MSA_PACKAGES`
* The image of the `run_fold` task: `fold_image`
We then assemble the image and the environment:
```python
main_image = fold_image.with_pip_packages(*MSA_PACKAGES)

env = flyte.TaskEnvironment(
    name="multi_env",
    depends_on=[fold_env, msa_env],
    image=main_image,
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py*
The image for the `main` task (`main_image`) is built by starting with `fold_image` (the image for the `run_fold` task) and adding `MSA_PACKAGES` (the dependency list for the `run_msa` task).
This ensures that `main_image` includes all dependencies needed by both the `run_fold` and `run_msa` tasks.
The environment for the `main` task is defined with:
* The image `main_image`. This ensures that the `main` task has all the dependencies it needs.
* A `depends_on` list that includes both `fold_env` and `msa_env`. This establishes the deploy-time dependencies on those environments.
Finally, we define the `main` task itself:
```python
@env.task
def main(sequence: str) -> list[str]:
    """Given a sequence, outputs files containing the protein structure

    This requires model weights + gpus + large database on aws fsx lustre
    """
    print(f"Running AlphaFold2 for sequence: {sequence}")
    msa = run_msa(sequence)
    print(f"MSA result: {msa}, passing to fold task")
    results = run_fold(sequence, msa)
    print(f"Fold results: {results}")
    return results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py*
Here we call, in turn, the `run_msa` and `run_fold` tasks.
Since we call them directly rather than as remote tasks, we had to ensure that `main_image` includes all dependencies needed by both tasks.
The final piece of the puzzle is the `if __name__ == "__main__":` block that allows us to run the `main` task on the configured Flyte backend:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, "AAGGTTCCAA")
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py*
Now you can run the workflow with:
```bash
python main.py
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/retries-and-timeouts ===
# Retries and timeouts
Flyte provides robust error handling through configurable retry strategies and timeout controls.
These parameters help ensure task reliability and prevent resource waste from runaway processes.
## Retries
The `retries` parameter controls how many times a failed task should be retried before giving up.
A "retry" is any attempt after the initial attempt.
In other words, `retries=3` means the task may be attempted up to 4 times in total (1 initial + 3 retries).
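The arithmetic is simple but worth making explicit:

```python
def max_attempts(retries: int) -> int:
    # One initial attempt plus one attempt per configured retry.
    return 1 + retries

print(max_attempts(3))  # 4 attempts in total
```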
The `retries` parameter can be configured in either the `@env.task` decorator or using `override` when invoking the task.
It cannot be configured in the `TaskEnvironment` definition.
The code for the examples below can be found on [GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py).
### Retry example
First we import the required modules and set up a task environment:
```python
import random
from datetime import timedelta

import flyte

env = flyte.TaskEnvironment(name="my-env")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py*
Then we configure our task to retry up to 3 times if it fails (for a total of 4 attempts). We also define the driver task `main` that calls the `retry` task:
```python
@env.task(retries=3)
async def retry() -> str:
    if random.random() < 0.7:  # 70% failure rate
        raise Exception("Task failed!")
    return "Success!"

@env.task
async def main() -> list[str]:
    results = []
    try:
        results.append(await retry())
    except Exception as e:
        results.append(f"Failed: {e}")
    try:
        results.append(await retry.override(retries=5)())
    except Exception as e:
        results.append(f"Failed: {e}")
    return results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py*
Note that we call `retry` twice: first without any `override`, and then with an `override` to increase the retries to 5 (for a total of 6 attempts).
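Because each attempt here fails independently with probability 0.7, retries sharply reduce the chance of overall failure. As a quick back-of-the-envelope check (plain Python, not part of the example workflow):

```python
def overall_failure(per_attempt_failure: float, retries: int) -> float:
    # The task fails only if the initial attempt and every retry fail.
    return per_attempt_failure ** (1 + retries)

print(round(overall_failure(0.7, 3), 4))  # retries=3 -> 0.7**4 = 0.2401
print(round(overall_failure(0.7, 5), 4))  # retries=5 -> 0.7**6 = 0.1176
```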
Finally, we configure Flyte and invoke the `main` task:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py*
## Timeouts
The `timeout` parameter sets limits on how long a task can run, preventing resource waste from stuck processes.
It supports multiple formats for different use cases.
The `timeout` parameter can be configured in either the `@env.task` decorator or using `override` when invoking the task.
It cannot be configured in the `TaskEnvironment` definition.
The code for the example below can be found on [GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py).
### Timeout example
First, we import the required modules and set up a task environment:
```python
import random
from datetime import timedelta
import asyncio

import flyte
from flyte import Timeout

env = flyte.TaskEnvironment(name="my-env")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py*
Our first task sets a timeout using seconds as an integer:
```python
@env.task(timeout=60)  # 60 seconds
async def timeout_seconds() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_seconds completed"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py*
We can also set a timeout using a `timedelta` object for more readable durations:
```python
@env.task(timeout=timedelta(minutes=1))
async def timeout_timedelta() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_timedelta completed"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py*
You can also set separate timeouts for maximum execution time and maximum queue time using the `Timeout` class:
```python
@env.task(timeout=Timeout(
    max_runtime=timedelta(minutes=1),  # Max execution time per attempt
    max_queued_time=timedelta(minutes=1)  # Max time in queue before starting
))
async def timeout_advanced() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_advanced completed"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py*
You can also combine retries and timeouts for resilience and resource control:
```python
@env.task(
    retries=3,
    timeout=Timeout(
        max_runtime=timedelta(minutes=1),
        max_queued_time=timedelta(minutes=1)
    )
)
async def timeout_with_retry() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_with_retry completed"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py*
Here we specify:
- Up to 3 retries (4 attempts in total).
- Each attempt times out after 1 minute.
- The task fails if queued for more than 1 minute.
- Total possible runtime: 1 minute of queue time + (1 minute × 4 attempts).
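A quick worst-case computation for these settings (assuming queue time is incurred once and every attempt runs to its full timeout):

```python
from datetime import timedelta

def worst_case(retries: int, max_runtime: timedelta, max_queued: timedelta) -> timedelta:
    # One initial attempt plus the retries, each bounded by max_runtime.
    attempts = 1 + retries
    return max_queued + attempts * max_runtime

print(worst_case(3, timedelta(minutes=1), timedelta(minutes=1)))  # 0:05:00
```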
We define the `main` driver task that calls all the timeout tasks concurrently and returns their outputs as a list. The return value for failed tasks will indicate failure:
```python
@env.task
async def main() -> list[str]:
    tasks = [
        timeout_seconds(),
        timeout_seconds.override(timeout=120)(),  # Override to 120 seconds
        timeout_timedelta(),
        timeout_advanced(),
        timeout_with_retry(),
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    output = []
    for r in results:
        if isinstance(r, Exception):
            output.append(f"Failed: {r}")
        else:
            output.append(r)
    return output
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py*
Note that we also demonstrate overriding the timeout for `timeout_seconds` to 120 seconds when calling it.
Finally, we configure Flyte and invoke the `main` task:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py*
Proper retry and timeout configuration ensures your Flyte workflows are both reliable and efficient, handling transient failures gracefully while preventing resource waste.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/triggers ===
# Triggers
Triggers allow you to automate and parameterize an execution by scheduling its start time and providing overrides for its task inputs.
Currently, only **schedule triggers** are supported.
This type of trigger runs a task based on a Cron expression or a fixed-rate schedule.
Support is coming for other trigger types, such as:
* Webhook triggers: Hit an API endpoint to run your task.
* Artifact triggers: Run a task when a specific artifact is produced.
## Triggers are set in the task decorator
A trigger is created by setting the `triggers` parameter in the task decorator to a `flyte.Trigger` object or a list of such objects (triggers are not settable at the `TaskEnvironment` definition or `task.override` levels).
Here is a simple example:
```python
import flyte
from datetime import datetime, timezone

env = flyte.TaskEnvironment(name="trigger_env")

@env.task(triggers=flyte.Trigger.hourly())  # Every hour
def hourly_task(trigger_time: datetime, x: int = 1) -> str:
    return f"Hourly example executed at {trigger_time.isoformat()} with x={x}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
Here we use a predefined schedule trigger to run the `hourly_task` every hour.
Other predefined triggers can be used similarly (see **Configure tasks > Triggers > Predefined schedule triggers** below).
If you want full control over the trigger behavior, you can define a trigger using the `flyte.Trigger` class directly.
## `flyte.Trigger`
For complete parameter documentation, see the [`Trigger`](../../api-reference/flyte-sdk/packages/flyte/trigger), [`Cron`](../../api-reference/flyte-sdk/packages/flyte/cron), and [`FixedRate`](../../api-reference/flyte-sdk/packages/flyte/fixedrate) API references.
The `Trigger` class allows you to define custom triggers with full control over scheduling and execution behavior. It has the following signature:
```python
flyte.Trigger(
    name,
    automation,
    description="",
    auto_activate=True,
    inputs=None,
    env_vars=None,
    interruptible=None,
    overwrite_cache=False,
    queue=None,
    labels=None,
    annotations=None
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
### Core Parameters
**`name: str`** (required)
The unique identifier for the trigger within your project/domain.
**`automation: Union[Cron, FixedRate]`** (required)
Defines when the trigger fires. Use `flyte.Cron("expression")` for Cron-based scheduling or `flyte.FixedRate(interval_minutes, start_time=...)` for fixed intervals.
### Configuration Parameters
**`description: str = ""`**
Human-readable description of the trigger's purpose.
**`auto_activate: bool = True`**
Whether the trigger should be automatically activated when deployed. Set to `False` to deploy inactive triggers that require manual activation.
**`inputs: Dict[str, Any] | None = None`**
Default parameter values for the task when triggered. Use `flyte.TriggerTime` as a value to inject the trigger execution timestamp into that parameter.
### Runtime Override Parameters
**`env_vars: Dict[str, str] | None = None`**
Environment variables to set for triggered executions, overriding the task's default environment variables.
**`interruptible: bool | None = None`**
Whether triggered executions can be interrupted (useful for cost optimization with spot/preemptible instances). Overrides the task's interruptible setting.
**`overwrite_cache: bool = False`**
Whether to bypass/overwrite task cache for triggered executions, ensuring fresh computation.
**`queue: str | None = None`**
Specific execution queue for triggered runs, overriding the task's default queue.
### Metadata Parameters
**`labels: Mapping[str, str] | None = None`**
Key-value labels for organizing and filtering triggers (e.g., team, component, priority).
**`annotations: Mapping[str, str] | None = None`**
Additional metadata, often used by infrastructure tools for compliance, monitoring, or cost tracking.
Here's a comprehensive example showing all parameters:
```python
comprehensive_trigger = flyte.Trigger(
    name="monthly_financial_report",
    automation=flyte.Cron("0 6 1 * *", timezone="America/New_York"),
    description="Monthly financial report generation for executive team",
    auto_activate=True,
    inputs={
        "report_date": flyte.TriggerTime,
        "report_type": "executive_summary",
        "include_forecasts": True
    },
    env_vars={
        "REPORT_OUTPUT_FORMAT": "PDF",
        "EMAIL_NOTIFICATIONS": "true"
    },
    interruptible=False,  # Critical report, use dedicated resources
    overwrite_cache=True,  # Always fresh data
    queue="financial-reports",
    labels={
        "team": "finance",
        "criticality": "high",
        "automation": "scheduled"
    },
    annotations={
        "compliance.company.com/sox-required": "true",
        "backup.company.com/retain-days": "2555"  # 7 years
    }
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
## The `automation` parameter with `flyte.FixedRate`
You can define a fixed-rate schedule trigger by setting the `automation` parameter of the `flyte.Trigger` to an instance of `flyte.FixedRate`.
The `flyte.FixedRate` has the following signature:
```python
flyte.FixedRate(
interval_minutes,
start_time=None
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
### Parameters
**`interval_minutes: int`** (required)
The interval between trigger executions in minutes.
**`start_time: datetime | None`**
When to start the fixed rate schedule. If not specified, starts when the trigger is deployed and activated.
### Examples
```python
# Every 90 minutes, starting when deployed
every_90_min = flyte.Trigger(
"data_processing",
flyte.FixedRate(interval_minutes=90)
)
# Every 6 hours (360 minutes), starting at a specific time
specific_start = flyte.Trigger(
"batch_job",
flyte.FixedRate(
interval_minutes=360, # 6 hours
start_time=datetime(2025, 12, 1, 9, 0, 0) # Start Dec 1st at 9 AM
)
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
## The `automation` parameter with `flyte.Cron`
You can define a Cron-based schedule trigger by setting the `automation` parameter to an instance of `flyte.Cron`.
The `flyte.Cron` has the following signature:
```python
flyte.Cron(
cron_expression,
timezone=None
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
### Parameters
**`cron_expression: str`** (required)
The cron expression defining when the trigger should fire. Uses standard Unix cron format with five fields: minute, hour, day of month, month, and day of week.
**`timezone: str | None`**
The timezone for the cron expression. If not specified, it defaults to UTC. Uses standard timezone names like "America/New_York" or "Europe/London".
### Examples
```python
# Every day at 6 AM UTC
daily_trigger = flyte.Trigger(
"daily_report",
flyte.Cron("0 6 * * *")
)
# Every weekday at 9:30 AM Eastern Time
weekday_trigger = flyte.Trigger(
"business_hours_task",
flyte.Cron("30 9 * * 1-5", timezone="America/New_York")
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
#### Cron Expressions
Here are some common cron expressions you can use:
| Expression | Description |
|----------------|--------------------------------------|
| `0 0 * * *` | Every day at midnight |
| `0 9 * * 1-5` | Every weekday at 9 AM |
| `30 14 * * 6` | Every Saturday at 2:30 PM |
| `0 0 1 * *` | First day of every month at midnight |
| `0 0 25 * *` | 25th day of every month at midnight |
| `0 0 * * 0` | Every Sunday at midnight |
| `*/10 * * * *` | Every 10 minutes |
| `0 */2 * * *` | Every 2 hours |
For a full guide on Cron syntax, refer to [Crontab Guru](https://crontab.guru/).
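To make the five-field semantics concrete, here is a minimal matcher for a common subset of cron syntax (`*`, `*/n`, single values, ranges, and comma lists). This is an illustrative sketch only, not the scheduler Flyte actually uses, and it omits real cron details such as the day-of-month/day-of-week OR rule:

```python
from datetime import datetime

def _field_matches(field: str, value: int) -> bool:
    """Match one cron field. Supports *, */n, a, a-b, and comma lists."""
    for part in field.split(","):
        if part == "*":
            return True
        if part.startswith("*/"):
            if value % int(part[2:]) == 0:
                return True
        elif "-" in part:
            lo, hi = part.split("-")
            if int(lo) <= value <= int(hi):
                return True
        elif int(part) == value:
            return True
    return False

def cron_matches(expr: str, dt: datetime) -> bool:
    """True if dt satisfies a five-field cron expression (minute hour dom month dow)."""
    minute, hour, dom, month, dow = expr.split()
    return (
        _field_matches(minute, dt.minute)
        and _field_matches(hour, dt.hour)
        and _field_matches(dom, dt.day)
        and _field_matches(month, dt.month)
        and _field_matches(dow, dt.isoweekday() % 7)  # cron convention: 0 = Sunday
    )
```

For instance, `cron_matches("0 9 * * 1-5", dt)` is true only when `dt` is a weekday at exactly 9:00.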
## The `inputs` parameter
The `inputs` parameter allows you to provide default values for your task's parameters when the trigger fires.
This is essential for parameterizing your automated executions and passing trigger-specific data to your tasks.
### Basic Usage
```python
trigger_with_inputs = flyte.Trigger(
"data_processing",
flyte.Cron("0 6 * * *"), # Daily at 6 AM
inputs={
"batch_size": 1000,
"environment": "production",
"debug_mode": False
}
)
@env.task(triggers=trigger_with_inputs)
def process_data(batch_size: int, environment: str, debug_mode: bool = True) -> str:
return f"Processing {batch_size} items in {environment} mode"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
### Using `flyte.TriggerTime`
The special `flyte.TriggerTime` value is used in the `inputs` to indicate the task parameter into which Flyte will inject the trigger execution timestamp:
```python
timestamp_trigger = flyte.Trigger(
"daily_report",
flyte.Cron("0 0 * * *"), # Daily at midnight
inputs={
"report_date": flyte.TriggerTime, # Receives trigger execution time
"report_type": "daily_summary"
}
)
@env.task(triggers=timestamp_trigger)
def generate_report(report_date: datetime, report_type: str) -> str:
return f"Generated {report_type} for {report_date.strftime('%Y-%m-%d')}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
### Required vs optional parameters
> [!IMPORTANT]
> If your task has parameters without default values, you **must** provide values for them in the trigger inputs, otherwise the trigger will fail to execute.
```python
# ❌ This will fail - missing required parameter 'data_source'
bad_trigger = flyte.Trigger(
"bad_trigger",
flyte.Cron("0 0 * * *")
# Missing inputs for required parameter 'data_source'
)
@env.task(triggers=bad_trigger)
def bad_trigger_task(data_source: str, batch_size: int = 100) -> str:
return f"Processing from {data_source} with batch size {batch_size}"
# ✅ This works - all required parameters provided
good_trigger = flyte.Trigger(
"good_trigger",
flyte.Cron("0 0 * * *"),
inputs={
"data_source": "prod_database", # Required parameter
"batch_size": 500 # Override default
}
)
@env.task(triggers=good_trigger)
def good_trigger_task(data_source: str, batch_size: int = 100) -> str:
return f"Processing from {data_source} with batch size {batch_size}"
```
### Complex input types
You can pass various data types through trigger inputs:
```python
complex_trigger = flyte.Trigger(
"ml_training",
flyte.Cron("0 2 * * 1"), # Weekly on Monday at 2 AM
inputs={
"model_config": {
"learning_rate": 0.01,
"batch_size": 32,
"epochs": 100
},
"feature_columns": ["age", "income", "location"],
"validation_split": 0.2,
"training_date": flyte.TriggerTime
}
)
@env.task(triggers=complex_trigger)
def train_model(
model_config: dict,
feature_columns: list[str],
validation_split: float,
training_date: datetime
) -> str:
return f"Training model with {len(feature_columns)} features on {training_date}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
## Predefined schedule triggers
For common scheduling needs, Flyte provides predefined trigger methods that create Cron-based schedules without requiring you to specify cron expressions manually.
These are convenient shortcuts for frequently used scheduling patterns.
### Available Predefined Triggers
```python
minutely_trigger = flyte.Trigger.minutely() # Every minute
hourly_trigger = flyte.Trigger.hourly() # Every hour
daily_trigger = flyte.Trigger.daily() # Every day at midnight
weekly_trigger = flyte.Trigger.weekly() # Every week (Sundays at midnight)
monthly_trigger = flyte.Trigger.monthly() # Every month (1st day at midnight)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
For reference, here's what each predefined trigger is equivalent to:
```python
# These are functionally identical:
flyte.Trigger.minutely() == flyte.Trigger("minutely", flyte.Cron("* * * * *"))
flyte.Trigger.hourly() == flyte.Trigger("hourly", flyte.Cron("0 * * * *"))
flyte.Trigger.daily() == flyte.Trigger("daily", flyte.Cron("0 0 * * *"))
flyte.Trigger.weekly() == flyte.Trigger("weekly", flyte.Cron("0 0 * * 0"))
flyte.Trigger.monthly() == flyte.Trigger("monthly", flyte.Cron("0 0 1 * *"))
```
### Predefined Trigger Parameters
All predefined trigger methods (`minutely()`, `hourly()`, `daily()`, `weekly()`, `monthly()`) accept the same set of parameters:
```python
flyte.Trigger.daily(
trigger_time_input_key="trigger_time",
name="daily",
description="A trigger that runs daily at midnight",
auto_activate=True,
inputs=None,
env_vars=None,
interruptible=None,
overwrite_cache=False,
queue=None,
labels=None,
annotations=None
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
#### Core Parameters
**`trigger_time_input_key: str = "trigger_time"`**
The name of the task parameter that will receive the execution timestamp.
If no `trigger_time_input_key` is provided, the default is `trigger_time`.
With the default key, if the task has no parameter named `trigger_time`, the task still executes; the timestamp is simply not passed.
However, if you explicitly specify a `trigger_time_input_key` and your task does not have a parameter with that name, an error is raised at trigger deployment time.
**`name: str`**
The unique identifier for the trigger. Defaults to the method name (`"daily"`, `"hourly"`, etc.).
**`description: str`**
Human-readable description of the trigger's purpose. Each method has a sensible default.
#### Configuration Parameters
**`auto_activate: bool = True`**
Whether the trigger should be automatically activated when deployed. Set to `False` to deploy inactive triggers that require manual activation.
**`inputs: Dict[str, Any] | None = None`**
Additional parameter values for your task when triggered. The `trigger_time_input_key` parameter is automatically included with `flyte.TriggerTime` as its value.
#### Runtime Override Parameters
**`env_vars: Dict[str, str] | None = None`**
Environment variables to set for triggered executions, overriding the task's default environment variables.
**`interruptible: bool | None = None`**
Whether triggered executions can be interrupted (useful for cost optimization with spot/preemptible instances). Overrides the task's interruptible setting.
**`overwrite_cache: bool = False`**
Whether to bypass/overwrite task cache for triggered executions, ensuring fresh computation.
**`queue: str | None = None`**
Specific execution queue for triggered runs, overriding the task's default queue.
#### Metadata Parameters
**`labels: Mapping[str, str] | None = None`**
Key-value labels for organizing and filtering triggers (e.g., team, component, priority).
**`annotations: Mapping[str, str] | None = None`**
Additional metadata, often used by infrastructure tools for compliance, monitoring, or cost tracking.
### Trigger time in predefined triggers
By default, predefined triggers pass the execution time to the task parameter `trigger_time` (of type `datetime`), if that parameter exists on the task.
If no such parameter exists, the task will still be executed without error.
Optionally, you can customize the name of the parameter that receives the trigger execution timestamp by setting `trigger_time_input_key`. In this case, if the task does not have the specified parameter, an error is raised at trigger deployment time:
```python
@env.task(triggers=flyte.Trigger.daily(trigger_time_input_key="scheduled_at"))
def task_with_custom_trigger_time_input(scheduled_at: datetime) -> str:
return f"Executed at {scheduled_at}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
## Multiple triggers per task
You can attach multiple triggers to a single task by providing a list of triggers. This allows you to run the same task on different schedules or with different configurations:
```python
@env.task(triggers=[
flyte.Trigger.hourly(), # Predefined trigger
flyte.Trigger.daily(), # Another predefined trigger
flyte.Trigger("custom", flyte.Cron("0 */6 * * *")) # Custom trigger every 6 hours
])
def multi_trigger_task(trigger_time: datetime = flyte.TriggerTime) -> str:
# Different logic based on execution timing
if trigger_time.hour == 0: # Daily run at midnight
return f"Daily comprehensive processing at {trigger_time}"
else: # Hourly or custom runs
return f"Regular processing at {trigger_time.strftime('%H:%M')}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
You can mix and match trigger types, combining predefined triggers with custom triggers that use `flyte.Cron` or `flyte.FixedRate` automations (both described above).
## Deploying a task with triggers
We recommend that you define your triggers in code together with your tasks and deploy them together.
The Union UI displays:
* `Owner` - who last deployed the trigger.
* `Last updated` - who last activated or deactivated the trigger, and when. Note: if you deploy a trigger with `auto_activate=True` (the default), this will match the `Owner`.
* `Last Run` - when the last run was created by this trigger.
For development and debugging purposes, you can adjust and deploy individual triggers from the UI.
To deploy a task with its triggers, you can either use the Flyte CLI:
```bash
flyte deploy -p -d env
```
Or in Python:
```python
flyte.deploy(env)
```
Upon deployment, all triggers associated with a given task `T` are automatically switched to apply to the latest version of that task. Triggers on task `T` that were defined elsewhere (e.g., in the UI) will be deleted unless they are referenced in the task definition of `T`.
## Activating and deactivating triggers
By default, triggers are automatically activated upon deployment (`auto_activate=True`).
Alternatively, you can set `auto_activate=False` to deploy inactive triggers.
An inactive trigger will not create runs until activated.
```python
env = flyte.TaskEnvironment(name="my_task_env")
custom_cron_trigger = flyte.Trigger(
"custom_cron",
flyte.Cron("0 0 * * *"),
    auto_activate=False # Don't create runs yet
)
@env.task(triggers=custom_cron_trigger)
def custom_task() -> str:
return "Hello, world!"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
This trigger won't create runs until it is explicitly activated.
You can activate a trigger via the Flyte CLI:
```bash
flyte update trigger custom_cron my_task_env.custom_task --activate --project --domain
```
If you want to stop your trigger from creating new runs, you can deactivate it:
```bash
flyte update trigger custom_cron my_task_env.custom_task --deactivate --project --domain
```
You can also view and manage your deployed triggers in the Union UI.
## Trigger run timing
The timing of the first run created by a trigger depends on the type of trigger used (Cron-based or Fixed-rate) and whether the trigger is active upon deployment.
### Cron-based triggers
For Cron-based triggers, the first run is created at the next scheduled time (according to the cron expression) after trigger activation, and at each scheduled time thereafter.
* `0 0 * * *`: If deployed and activated at 17:00 today, the trigger first fires 7 hours later (at 0:00 the following day) and then every day at 0:00 thereafter.
* `*/15 14 * * 1-5`: If today is Tuesday and the trigger is activated at 17:00, it fires the next day (Wednesday) at 14:00, 14:15, 14:30, and 14:45, and then the same every subsequent weekday.
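The first-fire rule can be sketched for a simple daily schedule such as `0 0 * * *`. The helper below is hypothetical (not part of the Flyte SDK) and only models this one-slot-per-day case:

```python
from datetime import datetime, timedelta

def next_daily_fire(activated_at: datetime, hour: int = 0, minute: int = 0) -> datetime:
    """Next time a daily `minute hour * * *` schedule fires after activation."""
    candidate = activated_at.replace(hour=hour, minute=minute, second=0, microsecond=0)
    if candidate <= activated_at:
        candidate += timedelta(days=1)  # today's slot has already passed
    return candidate
```

Activating at 17:00 with a midnight schedule yields 0:00 the following day, matching the first example above.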
### Fixed-rate triggers without `start_time`
If no `start_time` is specified, then the first run will be created after the specified interval from the time of activation. No run will be created immediately upon activation, but the activation time will be used as the reference point for future runs.
#### No `start_time`, auto_activate: True
Let's say you define a fixed rate trigger with automatic activation like this:
```python
my_trigger = flyte.Trigger("my_trigger", flyte.FixedRate(60))
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
In this case, the first run will occur 60 minutes after the successful deployment of the trigger.
So, if you deployed this trigger at 13:15, the first run will occur at 14:15 and so on thereafter.
#### No `start_time`, auto_activate: False
On the other hand, let's say you define a fixed rate trigger without automatic activation like this:
```python
my_trigger = flyte.Trigger("my_trigger", flyte.FixedRate(60), auto_activate=False)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
Then you activate it after about 3 hours. In this case the first run will kick off 60 minutes after trigger activation.
If you deployed the trigger at 13:15 and activated it at 16:07, the first run will occur at 17:07.
### Fixed-rate triggers with `start_time`
If a `start_time` is specified, the timing of the first run depends on whether the trigger is active at `start_time` or not.
#### Fixed-rate with `start_time` while active
If a `start_time` is specified and the trigger is active at `start_time`, the first run will occur at `start_time` and then at the specified interval thereafter.
For example:
```python
my_trigger = flyte.Trigger(
"my_trigger",
# Runs every 60 minutes starting from October 26th, 2025, 10:00am
flyte.FixedRate(60, start_time=datetime(2025, 10, 26, 10, 0, 0)),
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
If you deploy this trigger on October 24th, 2025, the trigger will wait until October 26th 10:00am and will create the first run at exactly 10:00am.
#### Fixed-rate with `start_time` while inactive
If a `start_time` is specified but the trigger is activated after `start_time`, the first run is created at the next time point that aligns with the recurring trigger interval, using `start_time` as the reference point.
For example:
```python
custom_rate_trigger = flyte.Trigger(
"custom_rate",
# Runs every 60 minutes starting from October 26th, 2025, 10:00am
flyte.FixedRate(60, start_time=datetime(2025, 10, 26, 10, 0, 0)),
auto_activate=False
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
If the trigger is activated later than `start_time`, say on October 28th at 12:35 PM, the first run will be created on October 28th at 1:00 PM.
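Assuming the behavior described above, the first-run time of a fixed-rate trigger can be sketched with a hypothetical helper (not part of the Flyte SDK): the first run lands on the first interval boundary, counted from `start_time`, that falls at or after activation.

```python
import math
from datetime import datetime, timedelta

def first_run(start_time: datetime, interval_minutes: int, activated_at: datetime) -> datetime:
    """First run of a FixedRate trigger with a start_time, activated at activated_at."""
    if activated_at <= start_time:
        return start_time  # active before start_time: first run at start_time itself
    elapsed_minutes = (activated_at - start_time).total_seconds() / 60
    # Next interval boundary, counted from start_time, at or after activation
    steps = math.ceil(elapsed_minutes / interval_minutes)
    return start_time + timedelta(minutes=steps * interval_minutes)
```

For the example above (start October 26th 10:00 AM, 60-minute interval, activation October 28th 12:35 PM), this yields October 28th 1:00 PM.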
## Deleting triggers
If you decide that you don't need a trigger anymore, you can remove the trigger from the task definition and deploy the task again.
Alternatively, you can use the Flyte CLI:
```bash
flyte delete trigger custom_cron my_task_env.custom_task --project --domain
```
## Schedule time zones
### Setting time zone for a Cron schedule
Cron expressions are interpreted in UTC by default, but you can specify a custom time zone:
```python
sf_trigger = flyte.Trigger(
"sf_tz",
flyte.Cron(
"0 9 * * *", timezone="America/Los_Angeles"
), # Every day at 9 AM PT
inputs={"start_time": flyte.TriggerTime, "x": 1},
)
nyc_trigger = flyte.Trigger(
"nyc_tz",
flyte.Cron(
"1 12 * * *", timezone="America/New_York"
), # Every day at 12:01 PM ET
inputs={"start_time": flyte.TriggerTime, "x": 1},
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
The above two schedules will fire 1 minute apart, at 9 AM PT and 12:01 PM ET respectively.
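You can sanity-check this one-minute gap with Python's standard `zoneinfo` module; timezone-aware datetimes subtract as absolute instants (outside of DST edge cases, 9:00 AM Pacific is 12:00 PM Eastern):

```python
from datetime import datetime
from zoneinfo import ZoneInfo

# The two trigger fire times on an arbitrary day
pt_fire = datetime(2025, 6, 2, 9, 0, tzinfo=ZoneInfo("America/Los_Angeles"))
et_fire = datetime(2025, 6, 2, 12, 1, tzinfo=ZoneInfo("America/New_York"))

# Subtraction across zones compares the underlying UTC instants
gap = et_fire - pt_fire  # one minute
```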
### `flyte.TriggerTime` is always in UTC
The `flyte.TriggerTime` value is always in UTC. For timezone-aware logic, convert as needed:
```python
@env.task(triggers=flyte.Trigger.minutely(trigger_time_input_key="utc_trigger_time", name="timezone_trigger"))
def timezone_task(utc_trigger_time: datetime) -> str:
local_time = utc_trigger_time.replace(tzinfo=timezone.utc).astimezone()
return f"Task fired at {utc_trigger_time} UTC ({local_time} local)"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py*
### Daylight Saving Time behavior
When Daylight Saving Time (DST) begins or ends, it can affect when scheduled executions occur. Consider, for example, a schedule set for 2:30 AM local time.
On the day DST begins, clocks jump from 2:00 AM to 3:00 AM, so 2:30 AM does not exist. The trigger will not fire until the next occurrence of 2:30 AM, on the following day.
On the day DST ends, the hour from 1:00 AM to 2:00 AM repeats, so 1:30 AM occurs twice. A schedule set for 1:30 AM runs only once, at the first occurrence of 1:30 AM.
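These calendar edge cases can be observed directly with Python's standard `zoneinfo` module. The sketch below illustrates the skipped and repeated wall-clock times themselves (US DST dates for 2025), not Flyte's scheduler internals:

```python
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

ny = ZoneInfo("America/New_York")

# DST begins 2025-03-09 in the US: clocks jump 2:00 -> 3:00 AM,
# so 2:30 AM does not exist on the wall clock that day.
skipped = datetime(2025, 3, 9, 2, 30, tzinfo=ny)
normalized = skipped.astimezone(timezone.utc).astimezone(ny)  # lands on 3:30 AM EDT

# DST ends 2025-11-02: 1:30 AM occurs twice; `fold` selects which occurrence.
first_130 = datetime(2025, 11, 2, 1, 30, tzinfo=ny, fold=0).astimezone(timezone.utc)
second_130 = datetime(2025, 11, 2, 1, 30, tzinfo=ny, fold=1).astimezone(timezone.utc)
# The two occurrences of 1:30 AM are one real hour apart in UTC
```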
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/interruptible-tasks-and-queues ===
# Interruptible tasks and queues
## Interruptible tasks
Cloud providers offer discounted compute instances (AWS Spot Instances, GCP Preemptible VMs)
that can be reclaimed at any time. These instances are significantly cheaper than on-demand
instances but come with the risk of preemption.
Setting `interruptible=True` allows Flyte to schedule the task on these spot/preemptible instances
for cost savings:
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env",
interruptible=True,
)
@env.task
def train_model(data: list) -> dict:
return {"accuracy": 0.95}
```
### Setting at different levels
`interruptible` can be set at the `TaskEnvironment` level, the `@env.task` decorator level,
and at the `task.override()` invocation level. The more specific level always takes precedence.
This lets you set a default at the environment level and override per-task:
```python
import flyte
# All tasks in this environment are interruptible by default
env = flyte.TaskEnvironment(
name="my_env",
interruptible=True,
)
# This task uses the environment default (interruptible)
@env.task
def preprocess(data: list) -> list:
return [x * 2 for x in data]
# This task overrides to non-interruptible (critical, should not be preempted)
@env.task(interruptible=False)
def save_results(results: dict) -> str:
return "saved"
```
You can also override at invocation time:
```python
@env.task
async def main(data: list) -> str:
processed = preprocess(data=data)
# Run this specific invocation as non-interruptible
return save_results.override(interruptible=False)(results={"data": processed})
```
### Behavior on preemption
When a spot instance is reclaimed, the task is terminated and rescheduled.
Combine `interruptible=True` with [retries](./retries-and-timeouts) to handle preemptions gracefully:
```python
@env.task(interruptible=True, retries=3)
def train_model(data: list) -> dict:
return {"accuracy": 0.95}
```
> [!NOTE]
> Retries due to spot preemption do not count against the user-configured retry budget.
> System retries (for preemptions and other system-level failures) are tracked separately.
## Queues
Queues are named routing labels that map tasks to specific resource pools or execution clusters
in your infrastructure.
Setting a queue directs the task to the corresponding compute partition:
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env",
queue="gpu-pool",
)
@env.task
def train_model(data: list) -> dict:
return {"accuracy": 0.95}
```
### Setting at different levels
`queue` can be set at the `TaskEnvironment` level, the `@env.task` decorator level,
and at the `task.override()` invocation level. The more specific level takes precedence.
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env",
queue="default-pool",
)
# Uses environment-level queue ("default-pool")
@env.task
def preprocess(data: list) -> list:
return [x * 2 for x in data]
# Overrides to a different queue
@env.task(queue="gpu-pool")
def train_model(data: list) -> dict:
return {"accuracy": 0.95}
```
If no queue is specified at any level, the default queue is used.
> [!NOTE]
> Queues are configured as part of your Flyte deployment by your platform administrator.
> The available queue names depend on your infrastructure setup.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/task-plugins ===
# Task plugins
Flyte tasks are pluggable by design, allowing you to extend task execution beyond simple containers to support specialized compute frameworks and integrations.
## Default Execution: Containers
By default, Flyte tasks execute as single containers in Kubernetes. When you decorate a function with `@env.task`, Flyte packages your code into a container and runs it on the cluster. For more advanced scenarios requiring multiple containers in a single pod (such as sidecars for logging or data mounting), you can use [pod templates](./pod-templates), which allow you to customize the entire Kubernetes pod specification.
## Compute Plugins
Beyond native container execution, Flyte provides **compute plugins** that enable you to run distributed computing frameworks directly on Kubernetes. These plugins create ephemeral clusters specifically for your task execution, spinning them up on-demand and tearing them down when complete.
### Available Compute Plugins
Flyte supports several popular distributed computing frameworks through compute plugins:
- **Spark**: Run Apache Spark jobs using the Spark operator
- **Ray**: Execute Ray workloads for distributed Python applications and ML training
- **Dask**: Scale Python workflows with Dask distributed
- **PyTorch**: Run distributed training jobs using PyTorch and Kubeflow's training operator
### How Compute Plugins Work
Compute plugins create temporary, isolated clusters within the same Kubernetes environment as Flyte:
1. **Ephemeral clusters**: Each task execution gets its own cluster, spun up on-demand
2. **Kubernetes operators**: Flyte leverages specialized Kubernetes operators (Spark operator, Ray operator, etc.) to manage cluster lifecycle
3. **Native containerization**: The same container image system used for regular tasks works seamlessly with compute plugins
4. **Per-environment configuration**: You can define the cluster shape (number of workers, resources, etc.) using `plugin_config` in your `TaskEnvironment`
### Using Compute Plugins
To use a compute plugin, you need to:
1. **Install the plugin package**: Each plugin has a corresponding Python package (e.g., `flyteplugins-ray` for Ray)
2. **Configure the TaskEnvironment**: Set the `plugin_config` parameter with the plugin-specific configuration
3. **Write your task**: Use the framework's native APIs within your task function
#### Example: Ray Plugin
Here's how to run a distributed Ray task:
```python
import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
import flyte
# Define your Ray computation
@ray.remote
def compute_square(x):
return x * x
# Configure the Ray cluster
ray_config = RayJobConfig(
head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
worker_node_config=[WorkerNodeConfig(group_name="ray-workers", replicas=2)],
runtime_env={"pip": ["numpy", "pandas"]},
enable_autoscaling=False,
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=300,
)
# Create a task environment with Ray plugin configuration
image = (
flyte.Image.from_debian_base(name="ray")
.with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)
ray_env = flyte.TaskEnvironment(
name="ray_env",
plugin_config=ray_config,
image=image,
resources=flyte.Resources(cpu=(3, 4), memory=("3000Mi", "5000Mi")),
)
# Use the Ray cluster in your task
@ray_env.task
async def distributed_compute(n: int = 10) -> list[int]:
futures = [compute_square.remote(i) for i in range(n)]
return ray.get(futures)
```
When this task runs, Flyte will:
1. Spin up a Ray cluster with 1 head node and 2 worker nodes
2. Execute your task code in the Ray cluster
3. Tear down the cluster after completion
### Using Plugins on Union
Most compute plugins are enabled by default on Union or can be enabled upon request. Contact your Account Manager to confirm plugin availability or request specific plugins for your deployment.
## Backend Integrations
Beyond compute plugins, Flyte also supports **integrations** with external SaaS services and internal systems through **connectors**. These allow you to seamlessly interact with:
- **Data warehouses**: Snowflake, BigQuery, Redshift
- **Data platforms**: Databricks
- **Custom services**: Your internal APIs and services
Connectors enable Flyte to delegate task execution to these external systems while maintaining Flyte's orchestration, observability, and data lineage capabilities. See the **Configure tasks > Task plugins > connectors documentation** for more details on available integrations.
## Next Steps
For detailed guides on each compute plugin, including configuration options, best practices, and advanced examples, see the **Configure tasks > Task plugins > Plugins section** of the documentation. Each plugin guide covers:
- Installation and setup
- Configuration options
- Resource management
- Advanced use cases
- Troubleshooting tips
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/additional-task-settings ===
# Additional task settings
This page covers task configuration parameters that do not have their own dedicated page:
naming and metadata, default inputs, environment variables, and inline I/O thresholds.
For the full list of all task configuration parameters, see [Configure tasks](./_index).
## Naming and metadata
### `name`
The `name` parameter on `TaskEnvironment` is required.
It is combined with each task function name to form the fully-qualified task name.
For example, if you define a `TaskEnvironment` with `name="my_env"` and a task function `my_task`,
the fully-qualified task name is `my_env.my_task`.
The `name` must use `snake_case` or `kebab-case` and is immutable once set.
### `short_name`
The `short_name` parameter on `@env.task` (and `override()`) overrides the display name of a task in the UI graph view.
By default, the display name is the Python function name.
Overriding `short_name` does not change the fully-qualified task name.
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task(short_name="Train Model")
def train(data: list) -> dict:
return {"accuracy": 0.95}
```
### `description`
The `description` parameter on `TaskEnvironment` provides a description of the task environment (max 255 characters).
It is used for organizational purposes and can be viewed in the UI.
### `docs`
The `docs` parameter on `@env.task` accepts a `Documentation` object.
If not set explicitly, the documentation is auto-extracted from the task function's docstring.
```python
import flyte
from flyte import Documentation
env = flyte.TaskEnvironment(name="my_env")
@env.task(docs=Documentation(description="Trains a model on the given dataset."))
def train(data: list) -> dict:
"""This docstring is used if docs is not set explicitly."""
return {"accuracy": 0.95}
```
### `report`
The `report` parameter on `@env.task` controls whether an HTML report is generated for the task.
See [Reports](../task-programming/reports) for details.
### `links`
The `links` parameter on `@env.task` (and `override()`) attaches clickable URLs to tasks in the UI.
Use links to connect tasks to external tools like experiment trackers, monitoring dashboards, or logging systems.
Links are defined by implementing the [`Link`](../../api-reference/flyte-sdk/packages/flyte/link) protocol.
See [Links](../task-programming/links) for full details on creating and using links.
## Default inputs
Task functions support Python default parameter values. When a task parameter has a default, callers can omit it and the default is used.
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def process(data: list, batch_size: int = 32, verbose: bool = False) -> dict:
# batch_size defaults to 32, verbose defaults to False
...
```
When running via `flyte run`, parameters with defaults are optional:
```bash
# Uses defaults for batch_size and verbose
flyte run my_file.py process --data '[1, 2, 3]'
# Override a default
flyte run my_file.py process --data '[1, 2, 3]' --batch-size 64
```
When invoking programmatically, Python's normal default argument rules apply:
```python
result = flyte.run(process, data=[1, 2, 3]) # batch_size=32, verbose=False
result = flyte.run(process, data=[1, 2, 3], batch_size=64) # override
```
Defaults are part of the task's input schema and are visible in the UI when viewing the task.
## Environment variables
The `env_vars` parameter on `TaskEnvironment` injects plain-text environment variables into the task container.
It accepts a `Dict[str, str]`.
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env",
env_vars={
"LOG_LEVEL": "DEBUG",
"API_ENDPOINT": "https://api.example.com",
},
)
@env.task
def my_task() -> str:
import os
return os.environ["API_ENDPOINT"]
```
Environment variables can be overridden at the `task.override()` invocation level
(unless `reusable` is in effect).
Use `env_vars` for non-sensitive configuration values.
For sensitive values like API keys and credentials, use [`secrets`](./secrets) instead.
## Inline I/O threshold
The `max_inline_io_bytes` parameter on `@env.task` (and `override()`) controls the maximum
size for data passed directly in the task request and response
(e.g., primitives, strings, dictionaries).
Data exceeding this threshold raises an `InlineIOMaxBytesBreached` error.
The default value is 10 MiB (`10 * 1024 * 1024` bytes).
This setting does **not** affect [`flyte.io.File`, `flyte.io.Dir`](../task-programming/files-and-directories),
or [`flyte.DataFrame`](../task-programming/dataclasses-and-structures),
which are always offloaded to object storage regardless of size.
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
# Allow up to 50 MiB of inline data
@env.task(max_inline_io_bytes=50 * 1024 * 1024)
def process_large_dict(data: dict) -> dict:
return {k: v * 2 for k, v in data.items()}
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming ===
# Build tasks
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
This section covers the essential programming patterns and techniques for developing robust Flyte workflows. Once you understand the basics of task configuration, these guides will help you build sophisticated, production-ready data pipelines and machine learning workflows.
## What you'll learn
The task programming section covers key patterns for building effective Flyte workflows:
**Data handling and types**
- **Build tasks > Files and directories**: Work with large datasets using Flyte's efficient file and directory types that automatically handle data upload, storage, and transfer between tasks.
- **Build tasks > DataFrames**: Pass DataFrames between tasks without downloading data into memory, with support for Pandas, Polars, PyArrow, Dask, and other DataFrame backends.
- **Build tasks > Data classes and structures**: Use Python data classes and Pydantic models as task inputs and outputs to create well-structured, type-safe workflows.
- **Build tasks > Custom context**: Use custom context to pass metadata through your task execution hierarchy without adding parameters to every task.
**Execution patterns**
- **Build tasks > Fanout**: Scale your workflows by running many tasks in parallel, perfect for processing large datasets or running hyperparameter sweeps.
- **Build tasks > Controlling parallel execution**: Limit concurrent task executions using semaphores or `flyte.map` concurrency for rate-limited APIs, GPU quotas, and resource-constrained workflows.
- **Build tasks > Human-in-the-loop**: Pause workflow execution at a checkpoint and wait for a human to provide input or approval before continuing.
- **Build tasks > Grouping actions**: Organize related task executions into logical groups for better visualization and management in the UI.
- **Build tasks > Run a bioinformatics tool**: Run arbitrary containers in any language without the Flyte SDK installed, using Flyte's copilot sidecar for seamless data flow.
- **Build tasks > Remote tasks**: Use previously deployed tasks without importing their code or dependencies, enabling team collaboration and task reuse.
- **Configure tasks > Pod templates**: Extend tasks with Kubernetes pod templates to add sidecars, volume mounts, and advanced Kubernetes configurations.
- **Build tasks > Abort and cancel actions**: Stop in-progress actions automatically, programmatically, or manually via the CLI and UI.
- **Build tasks > Regular async function (not a task)**: Advanced patterns like task forwarding and other specialized task execution techniques.
**Development and debugging**
- **Build tasks > Notebooks**: Write and iterate on workflows directly in Jupyter notebooks for interactive development and experimentation.
- **Build tasks > Test business logic directly**: Test your Flyte tasks using direct invocation for business logic or `flyte.run()` for Flyte-specific features.
- **Build tasks > Links**: Add clickable URLs to tasks in the Flyte UI, connecting them to external tools like experiment trackers and monitoring dashboards.
- **Build tasks > Reports**: Generate custom HTML reports during task execution to display progress, results, and visualizations in the UI.
- **Build tasks > Traces**: Add fine-grained observability to helper functions within your tasks for better debugging and resumption capabilities.
- **Build tasks > Error handling**: Implement robust error recovery strategies, including automatic resource scaling and graceful failure handling.
## When to use these patterns
These programming patterns become essential as your workflows grow in complexity:
- Use **fanout** when you need to process multiple items concurrently or run parameter sweeps. Use **controlling parallel execution** when you need to limit how many run at the same time.
- Implement **error handling** for production workflows that need to recover from infrastructure failures.
- Apply **grouping** to organize complex workflows with many task executions.
- Leverage **files and directories** when working with large datasets that don't fit in memory.
- Use **DataFrames** to efficiently pass tabular data between tasks across different processing engines.
- Choose **container tasks** when you need to run code in non-Python languages, use legacy containers, or execute AI-generated code in sandboxes.
- Use **remote tasks** to reuse tasks deployed by other teams without managing their dependencies.
- Apply **pod templates** when you need advanced Kubernetes features like sidecars or specialized storage configurations.
- Use **traces** to debug non-deterministic operations like API calls or ML inference.
- Use **links** to connect tasks to external tools like Weights & Biases, Grafana, or custom dashboards directly from the Flyte UI.
- Create **reports** to monitor long-running workflows and share results with stakeholders.
- Use **custom context** when you need lightweight, cross-cutting metadata to flow through your task hierarchy without becoming part of the task's logical inputs.
- Write **unit tests** to validate your task logic and ensure type transformations work correctly before deployment.
- Use **abort and cancel** to stop unnecessary actions when conditions change, such as early convergence in HPO or manual intervention.
- Use **human-in-the-loop** to insert approval gates or data collection checkpoints into automated workflows.
Each guide includes practical examples and best practices to help you implement these patterns effectively in your own workflows.
## Subpages
- **Build tasks > Files and directories**
- **Build tasks > Data classes and structures**
- **Build tasks > DataFrames**
- **Build tasks > Custom types**
- **Build tasks > Custom context**
- **Build tasks > Abort and cancel actions**
- **Build tasks > Run a bioinformatics tool**
- **Build tasks > Links**
- **Build tasks > Reports**
- **Build tasks > Notebooks**
- **Build tasks > Remote tasks**
- **Build tasks > Error handling**
- **Build tasks > Traces**
- **Build tasks > Grouping actions**
- **Build tasks > Fanout**
- **Build tasks > Controlling parallel execution**
- **Build tasks > Human-in-the-loop**
- **Build tasks > Regular async function (not a task)**
- **Build tasks > Test business logic directly**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/files-and-directories ===
# Files and directories
Flyte provides the [`flyte.io.File`](../../api-reference/flyte-sdk/packages/flyte.io/file) and
[`flyte.io.Dir`](../../api-reference/flyte-sdk/packages/flyte.io/dir) types to represent files and directories, respectively.
Together with [`flyte.io.DataFrame`](./dataframes), they constitute the *offloaded data types*: unlike [materialized types](./dataclasses-and-structures) such as data classes, they pass references rather than full data content.
A variable of an offloaded type does not contain its actual data, but rather a reference to the data.
The actual data is stored in the internal blob store of your Union/Flyte instance.
When a variable of an offloaded type is first created, its data is uploaded to the blob store.
It can then be passed from task to task as a reference.
The actual data is only downloaded from the blob store when the task needs to access it, for example, when the task calls `open()` on a `File` or `Dir` object.
This allows Flyte to efficiently handle large files and directories without needing to transfer the data unnecessarily.
Even very large data objects like video files and DNA datasets can be passed efficiently between tasks.
The `File` and `Dir` classes provide both `sync` and `async` methods to interact with the data.
## Example usage
The examples below show the basic use-cases of uploading files and directories created locally, and using them as inputs to a task.
```
import asyncio
import tempfile
from pathlib import Path
import flyte
from flyte.io import Dir, File
env = flyte.TaskEnvironment(name="files-and-folders")
@env.task
async def write_file(name: str) -> File:
# Create a file and write some content to it
with open("test.txt", "w") as f:
f.write(f"hello world {name}")
# Upload the file using flyte
uploaded_file_obj = await File.from_local("test.txt")
return uploaded_file_obj
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py*
The upload happens when the [`File.from_local`](../../api-reference/flyte-sdk/packages/flyte.io/file#from_local) command is called.
Because the upload would otherwise block execution, `File.from_local` is implemented as an `async` function.
The Flyte SDK frequently uses this class constructor pattern, so you will see it with other types as well.
This slightly more complicated task calls the task above to produce `File` objects.
These are assembled into a directory, and the resulting `Dir` object is returned, again via `from_local`.
```
@env.task
async def write_and_check_files() -> Dir:
coros = []
for name in ["Alice", "Bob", "Eve"]:
coros.append(write_file(name=name))
vals = await asyncio.gather(*coros)
temp_dir = tempfile.mkdtemp()
for file in vals:
async with file.open("rb") as fh:
contents = await fh.read()
# Convert bytes to string
contents_str = contents.decode('utf-8') if isinstance(contents, bytes) else str(contents)
print(f"File {file.path} contents: {contents_str}")
new_file = Path(temp_dir) / file.name
with open(new_file, "w") as out: # noqa: ASYNC230
out.write(contents_str)
print(f"Files written to {temp_dir}")
# walk the directory and ls
for path in Path(temp_dir).iterdir():
print(f"File: {path.name}")
my_dir = await Dir.from_local(temp_dir)
return my_dir
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py*
Finally, these tasks show how to use an offloaded type as an input.
The `File` and `Dir` objects provide helper methods like `walk` and `open` that behave as you would expect.
```
@env.task
async def check_dir(my_dir: Dir):
print(f"Dir {my_dir.path} contents:")
async for file in my_dir.walk():
print(f"File: {file.name}")
async with file.open("rb") as fh:
contents = await fh.read()
# Convert bytes to string
contents_str = contents.decode('utf-8') if isinstance(contents, bytes) else str(contents)
print(f"Contents: {contents_str}")
@env.task
async def create_and_check_dir():
my_dir = await write_and_check_files()
await check_dir(my_dir=my_dir)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(create_and_check_dir)
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py*
## JSONL files
The `flyteplugins-jsonl` package extends `File` and `Dir` with JSONL-aware types: `JsonlFile` and `JsonlDir`. They add streaming record-level read and write on top of the standard file/directory capabilities, with optional [zstd](https://github.com/facebook/zstd) compression and automatic shard rotation for large datasets.
Records are serialized with [orjson](https://github.com/ijl/orjson) for high performance. Both types provide async and sync APIs where every read/write method has a `_sync` variant (e.g. `iter_records_sync()`, `writer_sync()`).
```bash
pip install flyteplugins-jsonl
```
### Setup
```
import flyte
from flyteplugins.jsonl import JsonlDir, JsonlFile
env = flyte.TaskEnvironment(
name="jsonl-examples",
image=flyte.Image.from_debian_base(name="jsonl").with_pip_packages(
"flyteplugins-jsonl"
),
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
### JsonlFile
`JsonlFile` is a `File` subclass for single JSONL files. Use its async context manager to write records incrementally without loading the entire dataset into memory:
```
@env.task
async def write_records() -> JsonlFile:
"""Write records to a single JSONL file."""
out = JsonlFile.new_remote("results.jsonl")
async with out.writer() as writer:
for i in range(500_000):
await writer.write({"id": i, "score": i * 0.1})
return out
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
Reading is equally streaming:
```
@env.task
async def read_records(data: JsonlFile) -> int:
"""Read records from a JsonlFile and return the count."""
count = 0
async for record in data.iter_records():
count += 1
return count
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
### Compression
Both `JsonlFile` and `JsonlDir` support zstd compression transparently based on the file extension. Use `.jsonl.zst` (or `.jsonl.zstd`) to enable compression:
```
@env.task
async def write_compressed() -> JsonlFile:
"""Write a zstd-compressed JSONL file.
Compression is activated by using a .jsonl.zst extension.
Both reading and writing handle compression transparently.
"""
out = JsonlFile.new_remote("results.jsonl.zst")
async with out.writer(compression_level=3) as writer:
for i in range(100_000):
await writer.write({"id": i, "compressed": True})
return out
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
Reading compressed files requires no code changes; the compression format is detected automatically from the extension.
### JsonlDir
`JsonlDir` is a `Dir` subclass that shards records across multiple JSONL files (named `part-00000.jsonl`, `part-00001.jsonl`, etc.). When a shard reaches the record count or byte size threshold, a new shard is opened automatically. This keeps individual files at a manageable size even for very large datasets:
```
@env.task
async def write_large_dataset() -> JsonlDir:
"""Write a large dataset to a sharded JsonlDir.
JsonlDir automatically rotates to a new shard file once the
current shard reaches the record or byte limit. Shards are named
part-00000.jsonl, part-00001.jsonl, etc.
"""
out = JsonlDir.new_remote("dataset/")
async with out.writer(
max_records_per_shard=100_000,
max_bytes_per_shard=256 * 1024 * 1024, # 256 MB
) as writer:
for i in range(500_000):
await writer.write({"index": i, "value": i * i})
return out
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
Compressed shards are also supported by specifying the `shard_extension`:
```
@env.task
async def write_compressed_dir() -> JsonlDir:
"""Write zstd-compressed shards by specifying the shard extension."""
out = JsonlDir.new_remote("compressed_dataset/")
async with out.writer(
shard_extension=".jsonl.zst",
max_records_per_shard=50_000,
) as writer:
for i in range(200_000):
await writer.write({"id": i, "data": f"payload-{i}"})
return out
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
Reading iterates across all shards transparently. The next shard is prefetched in the background to overlap network I/O with processing:
```
@env.task
async def sum_values(dataset: JsonlDir) -> int:
"""Read all records across all shards and compute a sum.
Iteration is transparent across shards and handles mixed
compressed/uncompressed shards automatically. The next shard is
prefetched in the background for higher throughput.
"""
total = 0
async for record in dataset.iter_records():
total += record["value"]
return total
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
If you open a writer on a directory that already contains shards, the writer detects existing shard indices and continues from the next one, making it safe to append data to an existing `JsonlDir`.
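The continue-from-the-next-index behavior can be sketched as follows (a plain-Python illustration of the documented append semantics; the plugin's actual implementation may differ):

```python
import re

# Matches part-NNNNN.jsonl with an optional .zst/.zstd compression suffix
_SHARD = re.compile(r"part-(\d{5})\.jsonl(\.zstd?)?$")

def next_shard_index(existing_names: list[str]) -> int:
    """Continue numbering after the highest existing part-NNNNN shard."""
    indices = [int(m.group(1)) for name in existing_names
               if (m := _SHARD.match(name))]
    return max(indices) + 1 if indices else 0

print(next_shard_index(["part-00000.jsonl", "part-00001.jsonl"]))  # 2
print(next_shard_index([]))                                        # 0
```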
### Error handling
All read methods accept an `on_error` parameter to control how corrupt or malformed lines are handled:
- `"raise"` (default): propagate parse errors immediately
- `"skip"`: log a warning and skip corrupt lines
- A callable `(line_number, raw_line, exception) -> None` for custom handling
```
@env.task
async def read_with_error_handling(data: JsonlFile) -> int:
"""Read records, skipping any corrupt lines instead of raising."""
count = 0
async for record in data.iter_records(on_error="skip"):
count += 1
return count
@env.task
async def read_with_custom_handler(data: JsonlFile) -> int:
"""Read records with a custom error handler that collects errors."""
errors: list[dict] = []
def on_error(line_number: int, raw_line: bytes, exc: Exception) -> None:
errors.append({"line": line_number, "error": str(exc)})
count = 0
async for record in data.iter_records(on_error=on_error):
count += 1
print(f"{count} valid records, {len(errors)} errors")
return count
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
### Batch iteration
For bulk processing, both `JsonlFile` and `JsonlDir` support batched iteration. `iter_batches()` yields lists of dicts:
```
@env.task
async def process_in_batches(dataset: JsonlDir) -> int:
"""Process records in batches of dicts for bulk operations."""
total = 0
async for batch in dataset.iter_batches(batch_size=1000):
# Each batch is a list[dict]
total += len(batch)
return total
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
For analytics workloads, `iter_arrow_batches()` yields Arrow `RecordBatch` objects directly. This requires the optional `pyarrow` dependency:
```bash
pip install 'flyteplugins-jsonl[arrow]'
```
```
arrow_env = flyte.TaskEnvironment(
name="jsonl-arrow",
image=flyte.Image.from_debian_base(name="jsonl-arrow").with_pip_packages(
"flyteplugins-jsonl[arrow]"
),
)
@arrow_env.task
async def analyze_with_arrow(dataset: JsonlDir) -> float:
"""Stream records as Arrow RecordBatches for analytics.
    Memory usage is bounded by batch_size; the full dataset is
    never loaded into memory at once.
"""
import pyarrow as pa
batches = []
async for batch in dataset.iter_arrow_batches(batch_size=65_536):
batches.append(batch)
table = pa.Table.from_batches(batches)
    values = table.column("value").to_pylist()
    return sum(values) / len(values)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/jsonl.py*
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/dataclasses-and-structures ===
# Data classes and structures
Dataclasses and Pydantic models are fully supported in Flyte as **materialized data types**:
structured data where the full content is serialized and passed between tasks.
Use these as you would normally, passing them as inputs and outputs of tasks.
Unlike **offloaded types** like [`DataFrame`s](./dataframes), [`File`s and `Dir`s](./files-and-directories), data class and Pydantic model data is fully serialized, stored, and deserialized between tasks.
This makes them ideal for configuration objects, metadata, and smaller structured data where all fields should be serializable.
## Example: Combining Dataclasses and Pydantic Models
This example demonstrates how data classes and Pydantic models work together as materialized data types, showing nested structures and batch processing patterns:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pydantic",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from dataclasses import dataclass
from typing import List
from pydantic import BaseModel
import flyte
env = flyte.TaskEnvironment(name="ex-mixed-structures")
@dataclass
class InferenceRequest:
feature_a: float
feature_b: float
@dataclass
class BatchRequest:
requests: List[InferenceRequest]
batch_id: str = "default"
class PredictionSummary(BaseModel):
predictions: List[float]
average: float
count: int
batch_id: str
@env.task
async def predict_one(request: InferenceRequest) -> float:
"""
A dummy linear model: prediction = 2 * feature_a + 3 * feature_b + bias(=1.0)
"""
return 2.0 * request.feature_a + 3.0 * request.feature_b + 1.0
@env.task
async def process_batch(batch: BatchRequest) -> PredictionSummary:
"""
Processes a batch of inference requests and returns summary statistics.
"""
# Process all requests concurrently
tasks = [predict_one(request=req) for req in batch.requests]
predictions = await asyncio.gather(*tasks)
# Calculate statistics
average = sum(predictions) / len(predictions) if predictions else 0.0
return PredictionSummary(
predictions=predictions,
average=average,
count=len(predictions),
batch_id=batch.batch_id
)
@env.task
async def summarize_results(summary: PredictionSummary) -> str:
"""
Creates a text summary from the prediction results.
"""
return (
f"Batch {summary.batch_id}: "
f"Processed {summary.count} predictions, "
f"average value: {summary.average:.2f}"
)
@env.task
async def main() -> str:
batch = BatchRequest(
requests=[
InferenceRequest(feature_a=1.0, feature_b=2.0),
InferenceRequest(feature_a=3.0, feature_b=4.0),
InferenceRequest(feature_a=5.0, feature_b=6.0),
],
batch_id="demo_batch_001"
)
summary = await process_batch(batch)
result = await summarize_results(summary)
return result
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataclasses-and-structures/example.py*
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/dataframes ===
# DataFrames
By default, return values in Python are materialized - meaning the actual data is downloaded and loaded into memory. This applies to simple types like integers, as well as more complex types like DataFrames.
To avoid downloading large datasets into memory, Flyte 2 exposes [`flyte.io.DataFrame`](../../api-reference/flyte-sdk/packages/flyte.io/dataframe): a thin, uniform wrapper type for DataFrame-style objects that lets you pass a reference to the data rather than the fully materialized contents.
The `flyte.io.DataFrame` type provides serialization support for common engines like `pandas`, `polars`, `pyarrow`, and `dask`, enabling you to move data between different DataFrame backends.
## Setting up the environment and sample data
For our example we will start by setting up our task environment with the required dependencies and create some sample data.
```
from typing import Annotated
import numpy as np
import pandas as pd
import flyte
import flyte.io
env = flyte.TaskEnvironment(
    name="dataframe_usage",
    image=flyte.Image.from_debian_base().with_pip_packages("pandas", "pyarrow", "numpy"),
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
BASIC_EMPLOYEE_DATA = {
"employee_id": range(1001, 1009),
"name": ["Alice", "Bob", "Charlie", "Diana", "Ethan", "Fiona", "George", "Hannah"],
"department": ["HR", "Engineering", "Engineering", "Marketing", "Finance", "Finance", "HR", "Engineering"],
"hire_date": pd.to_datetime(
["2018-01-15", "2019-03-22", "2020-07-10", "2017-11-01", "2021-06-05", "2018-09-13", "2022-01-07", "2020-12-30"]
),
}
ADDL_EMPLOYEE_DATA = {
"employee_id": range(1001, 1009),
"salary": [55000, 75000, 72000, 50000, 68000, 70000, np.nan, 80000],
"bonus_pct": [0.05, 0.10, 0.07, 0.04, np.nan, 0.08, 0.03, 0.09],
"full_time": [True, True, True, False, True, True, False, True],
"projects": [
["Recruiting", "Onboarding"],
["Platform", "API"],
["API", "Data Pipeline"],
["SEO", "Ads"],
["Budget", "Forecasting"],
["Auditing"],
[],
["Platform", "Security", "Data Pipeline"],
],
}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py*
## Create a raw DataFrame
Now, let's create a task that returns a native Pandas DataFrame:
```
@env.task
async def create_raw_dataframe() -> pd.DataFrame:
return pd.DataFrame(BASIC_EMPLOYEE_DATA)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py*
This is the most basic use-case of how to pass DataFrames (of all kinds, not just Pandas).
We simply create the DataFrame as normal, and return it.
Because the task has been declared to return a supported native DataFrame type (in this case, `pandas.DataFrame`), Flyte will automatically detect it, serialize it correctly, and upload it at task completion, enabling it to be passed transparently to the next task.
Flyte supports auto-serialization for the following DataFrame types:
* `pandas.DataFrame`
* `pyarrow.Table`
* `dask.dataframe.DataFrame`
* `polars.DataFrame`
* `flyte.io.DataFrame` (see below)
## Create a flyte.io.DataFrame
Alternatively you can also create a `flyte.io.DataFrame` object directly from a native object with the `from_df` method:
```
@env.task
async def create_flyte_dataframe() -> Annotated[flyte.io.DataFrame, "parquet"]:
pd_df = pd.DataFrame(ADDL_EMPLOYEE_DATA)
fdf = flyte.io.DataFrame.from_df(pd_df)
return fdf
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py*
The `flyte.io.DataFrame` class creates a thin wrapper around objects of any standard DataFrame type. It serves as a generic "any DataFrame type" (a concept that Python itself does not currently offer).
As with native DataFrame types, Flyte will automatically serialize and upload the data at task completion.
The advantage of the unified `flyte.io.DataFrame` wrapper is that you can be explicit about the storage format that makes sense for your use case by using an `Annotated` type, where the second argument encodes the format or other lightweight hints. In the task above, the annotation `Annotated[flyte.io.DataFrame, "parquet"]` specifies that the DataFrame should be stored as Parquet.
## Automatically convert between types
You can leverage Flyte to automatically download and convert the DataFrame between types when needed:
```
@env.task
async def join_data(raw_dataframe: pd.DataFrame, flyte_dataframe: pd.DataFrame) -> flyte.io.DataFrame:
joined_df = raw_dataframe.merge(flyte_dataframe, on="employee_id", how="inner")
return flyte.io.DataFrame.from_df(joined_df)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py*
This task takes two DataFrames as input. We'll pass one raw Pandas DataFrame, and one `flyte.io.DataFrame`.
Flyte automatically converts the `flyte.io.DataFrame` to a Pandas DataFrame (since we declared that as the input type) before passing it to the task.
The actual download and conversion happens only when we access the data, in this case, when we do the merge.
## Downloading DataFrames
When a task receives a `flyte.io.DataFrame`, you can request a concrete backend representation. For example, to download as a pandas DataFrame:
```
@env.task
async def download_data(joined_df: flyte.io.DataFrame):
downloaded = await joined_df.open(pd.DataFrame).all()
print("Downloaded Data:\n", downloaded)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py*
The `open()` call delegates to the DataFrame handler for the stored format and converts to the requested in-memory type.
## Run the example
Finally, we can define a `main` function to run the tasks defined above and a `__main__` block to execute the workflow:
```
@env.task
async def main():
    raw_df = await create_raw_dataframe()
    flyte_df = await create_flyte_dataframe()
    joined_df = await join_data(raw_df, flyte_df)
    await download_data(joined_df)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py*
## Polars DataFrames
The `flyteplugins-polars` package extends Flyte's DataFrame support to `polars.DataFrame` and `polars.LazyFrame`. Install it alongside the core SDK and it registers automatically; no additional configuration is required.
```bash
pip install flyteplugins-polars
```
Both types are serialized as Parquet when passed between tasks, just like other DataFrame backends.
### Setup
```
import polars as pl
import flyte
env = flyte.TaskEnvironment(
name="polars-dataframes",
image=flyte.Image.from_debian_base(name="polars").with_pip_packages(
"flyteplugins-polars>=2.0.0", "polars"
),
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
EMPLOYEE_DATA = {
"employee_id": [1001, 1002, 1003, 1004, 1005, 1006],
"name": ["Alice", "Bob", "Charlie", "Diana", "Ethan", "Fiona"],
"department": ["Engineering", "Engineering", "Marketing", "Finance", "Finance", "Engineering"],
"salary": [75000, 72000, 50000, 68000, 70000, 80000],
"years_experience": [5, 4, 2, 6, 5, 7],
}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/polars_dataframes.py*
### Eager DataFrames
Use `pl.DataFrame` when you want immediate evaluation. Flyte serializes it to Parquet on output and deserializes it on input:
```
@env.task
async def create_dataframe() -> pl.DataFrame:
"""Create a Polars DataFrame.
Polars DataFrames are passed between tasks as serialized Parquet files
    stored in the Flyte blob store; no manual upload required.
"""
return pl.DataFrame(EMPLOYEE_DATA)
@env.task
async def filter_high_earners(df: pl.DataFrame) -> pl.DataFrame:
"""Filter and enrich a Polars DataFrame."""
return (
df.filter(pl.col("salary") > 60000)
.with_columns(
(pl.col("salary") / pl.col("years_experience")).alias("salary_per_year")
)
.sort("salary", descending=True)
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/polars_dataframes.py*
### Lazy DataFrames
Use `pl.LazyFrame` when you want to defer computation and let Polars optimize the full query plan before executing. Flyte handles serialization the same way as `pl.DataFrame`:
```
@env.task
async def create_lazyframe() -> pl.LazyFrame:
"""Create a Polars LazyFrame.
LazyFrames defer computation until collected, allowing Polars to
optimize the full query plan. They are serialized to Parquet just
like DataFrames when passed between tasks.
"""
return pl.LazyFrame(EMPLOYEE_DATA)
@env.task
async def aggregate_by_department(lf: pl.LazyFrame) -> pl.DataFrame:
"""Aggregate salary statistics by department using a LazyFrame.
The query plan is built lazily and executed only when collect() is called.
"""
return (
lf.group_by("department")
.agg(
pl.col("salary").mean().alias("avg_salary"),
pl.col("salary").max().alias("max_salary"),
pl.len().alias("headcount"),
)
.sort("avg_salary", descending=True)
.collect()
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/polars_dataframes.py*
The `collect()` call in `aggregate_by_department` is what triggers execution of the lazy plan. The `LazyFrame` passed between tasks is serialized as Parquet at that point.
### Run the example
```python
@env.task
async def main():
df = await create_dataframe()
filtered = await filter_high_earners(df=df)
print("High earners:")
print(filtered)
lf = await create_lazyframe()
summary = await aggregate_by_department(lf=lf)
print("Department summary:")
print(summary)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/polars_dataframes.py*
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/handling-custom-types ===
# Custom types
Flyte has a rich type system that handles most Python types automatically. However, there are cases where you may want to pass custom types into a run or between actions. By default, if Flyte doesn't recognize a type, it uses Python pickle to serialize the data. While this works, pickle has several drawbacks:
- **Inefficiency**: Pickle can be very inefficient for certain data types
- **Language compatibility**: Pickle is Python-specific and doesn't work with other languages
- **Version fragility**: Pickled data can break between Python versions
- **Opacity**: Pickled data appears as bytes or file links in the UI, with no automatic form generation
Consider types like Polars DataFrames or PyTorch Tensors. Using pickle for these is extremely inefficient compared to native serialization formats like Parquet or tensor-specific formats.
Flyte SDK addresses this by allowing you to create and share type extensions.
## Types of extensions
Flyte supports two types of type extensions:
1. **Type transformers**: For scalar types (integers, strings, files, directories, custom objects)
2. **DataFrame extensions**: For tabular data types that benefit from DataFrame-specific handling
DataFrame types are special because they have associated metadata (columns, schemas), can be serialized to efficient formats like Parquet, support parallel uploads from engines like Spark, and can be partitioned.
## Creating a type transformer
Type transformers convert between Python types and Flyte's internal representation. Here's how to create one for a custom `PositiveInt` type.
### Step 1: Define your custom type
```python
# custom_type.py
class PositiveInt:
"""A wrapper type that only accepts positive integers."""
def __init__(self, value: int):
if not isinstance(value, int):
raise TypeError(f"Expected int, got {type(value).__name__}")
if value <= 0:
raise ValueError(f"Expected positive integer, got {value}")
self._value = value
@property
def value(self) -> int:
return self._value
def __repr__(self) -> str:
return f"PositiveInt({self._value})"
```
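A quick standalone check of the wrapper's invariant (the class is repeated inline here in condensed form so the snippet runs on its own):

```python
class PositiveInt:
    """Condensed copy of the PositiveInt wrapper above, for a standalone check."""

    def __init__(self, value: int):
        if not isinstance(value, int):
            raise TypeError(f"Expected int, got {type(value).__name__}")
        if value <= 0:
            raise ValueError(f"Expected positive integer, got {value}")
        self._value = value


print(PositiveInt(42)._value)  # 42

# Invalid values are rejected at construction time:
for bad in (0, -5, "7"):
    try:
        PositiveInt(bad)
    except (TypeError, ValueError) as exc:
        print(type(exc).__name__)
```

Validating in the constructor means the type transformer (next step) can rely on every `PositiveInt` instance already being valid.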
### Step 2: Create the type transformer
```python
# transformer.py
from typing import Type
from flyteidl2.core import literals_pb2, types_pb2
from flyte import logger
from flyte.types import TypeEngine, TypeTransformer, TypeTransformerFailedError
from my_transformer.custom_type import PositiveInt
class PositiveIntTransformer(TypeTransformer[PositiveInt]):
"""
Type transformer for PositiveInt that validates and transforms positive integers.
"""
def __init__(self):
super().__init__(name="PositiveInt", t=PositiveInt)
def get_literal_type(self, t: Type[PositiveInt]) -> types_pb2.LiteralType:
"""Returns the Flyte literal type for PositiveInt."""
return types_pb2.LiteralType(
simple=types_pb2.SimpleType.INTEGER,
structure=types_pb2.TypeStructure(tag="PositiveInt"),
)
async def to_literal(
self,
python_val: PositiveInt,
python_type: Type[PositiveInt],
expected: types_pb2.LiteralType,
) -> literals_pb2.Literal:
"""Converts a PositiveInt instance to a Flyte Literal."""
if not isinstance(python_val, PositiveInt):
raise TypeTransformerFailedError(
f"Expected PositiveInt, got {type(python_val).__name__}"
)
return literals_pb2.Literal(
scalar=literals_pb2.Scalar(
primitive=literals_pb2.Primitive(integer=python_val.value)
)
)
async def to_python_value(
self,
lv: literals_pb2.Literal,
expected_python_type: Type[PositiveInt]
) -> PositiveInt:
"""Converts a Flyte Literal back to a PositiveInt instance."""
if not lv.scalar or not lv.scalar.primitive:
raise TypeTransformerFailedError(
f"Cannot convert literal {lv} to PositiveInt: missing scalar primitive"
)
value = lv.scalar.primitive.integer
try:
return PositiveInt(value)
except (TypeError, ValueError) as e:
raise TypeTransformerFailedError(
f"Cannot convert value {value} to PositiveInt: {e}"
)
def guess_python_type(
self,
literal_type: types_pb2.LiteralType
) -> Type[PositiveInt]:
"""Guesses the Python type from a Flyte literal type."""
if (
literal_type.simple == types_pb2.SimpleType.INTEGER
and literal_type.structure
and literal_type.structure.tag == "PositiveInt"
):
return PositiveInt
raise ValueError(f"Cannot guess PositiveInt from literal type {literal_type}")
```
### Step 3: Register the transformer
Create a registration function that can be called to register your transformer:
```python
def register_positive_int_transformer():
"""Register the PositiveIntTransformer in the TypeEngine."""
TypeEngine.register(PositiveIntTransformer())
logger.info("Registered PositiveIntTransformer in TypeEngine")
```
## Distributing type plugins
To share your type transformer as an installable package, configure it as a Flyte plugin using entry points.
### Configure pyproject.toml
Add the entry point to your `pyproject.toml`:
```toml
[project]
name = "my_transformer"
version = "0.1.0"
description = "Custom type transformer"
requires-python = ">=3.10"
dependencies = []
[project.entry-points."flyte.plugins.types"]
my_transformer = "my_transformer.transformer:register_positive_int_transformer"
```
The entry point group `flyte.plugins.types` tells Flyte to automatically load this transformer when the package is installed.
### Automatic loading
When your plugin package is installed, Flyte automatically loads the type transformer at runtime. This happens during `flyte.init()` or `flyte.init_from_config()`.
## Controlling plugin loading
Loading many type plugins can add overhead to initialization. You can disable automatic plugin loading:
```python
import flyte
# Disable automatic loading of type transformer plugins
flyte.init(load_plugin_type_transformers=False)
```
By default, `load_plugin_type_transformers` is `True`.
## Using custom types in tasks
Once registered, use your custom type like any built-in type:
```python
import flyte
from my_transformer.custom_type import PositiveInt
env = flyte.TaskEnvironment(name="custom_types")
@env.task
async def process_positive(value: PositiveInt) -> int:
"""Process a positive integer."""
return value.value * 2
if __name__ == "__main__":
flyte.init_from_config()
# The custom type works seamlessly
run = flyte.run(process_positive, value=PositiveInt(42))
run.wait()
print(run.outputs()[0]) # 84
```
## DataFrame extensions
For tabular data types, Flyte provides a specialized extension mechanism through `flyte.io.DataFrame`. DataFrame extensions support:
- Automatic conversion to/from Parquet format
- Column metadata and schema information
- Parallel uploads from distributed engines
- Partitioning support
DataFrame extensions use encoders and decoders from `flyte.io.extend`. Documentation for creating DataFrame extensions is coming soon.
## Best practices
1. **Use specific types over pickle**: Define type transformers for any custom types used frequently in your workflows
2. **Keep transformers lightweight**: Avoid expensive operations in `to_literal` and `to_python_value`
3. **Add validation**: Validate data in your transformer to catch errors early
4. **Use meaningful tags**: The `TypeStructure.tag` helps identify your type in the Flyte UI
5. **Be judicious with plugins**: Only install the plugins you need to minimize initialization overhead
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/custom-context ===
# Custom context
Custom context provides a mechanism for implicitly passing configuration and metadata through your entire task execution hierarchy without adding parameters to every task. It is ideal for cross-cutting concerns such as tracing, environment metadata, or experiment identifiers.
Think of custom context as **execution-scoped metadata** that automatically flows from parent to child tasks.
## Overview
Custom context is an implicit key-value configuration map that is automatically available to tasks during execution. It is stored in the blob store of your Union/Flyte instance together with the task's inputs, making it available across tasks without needing to pass it explicitly.
You can access it in a Flyte task via:
```python
flyte.ctx().custom_context
```
Custom context is fundamentally different from standard task inputs. Task inputs are explicit, strongly typed parameters that you declare as part of a task's signature. They directly influence the task's computation and therefore participate in Flyte's caching and reproducibility guarantees.
Custom context, on the other hand, is implicit metadata. It consists only of string key/value pairs, is not part of the task signature, and does not affect task caching. Because it is injected by the Flyte runtime rather than passed as a formal input, it should be used only for environmental or contextual information, not for data that changes the logical output of a task.
## When to use it and when not to
Custom context is perfect when you need metadata, not domain data, to flow through your tasks.
Good use cases:
- Tracing IDs, span IDs
- Experiment or run metadata
- Environment region, cluster ID
- Logging correlation keys
- Feature flags
- Session IDs for 3rd-party APIs (e.g., an LLM session)
Avoid using for:
- Business/domain data
- Inputs that change task outputs
- Anything affecting caching or reproducibility
- Large blobs of data (keep it small)
It is the cleanest mechanism when you need something available everywhere, but not logically an input to the computation.
## Setting custom context
There are two ways to set custom context for a Flyte run:
1. Set it once for the entire run when you launch (`with_runcontext`); this establishes the base context for the execution
2. Set or override it inside task code using the `flyte.custom_context(...)` context manager; this changes the active context for that task block and any nested tasks called from it
Both are legitimate and complementary. The important behavioral rules to understand are:
- `with_runcontext(...)` sets the run-level base. Values provided here are available everywhere unless overridden later. Use this for metadata that should apply to most or all tasks in the run (experiment name, top-level trace id, run id, etc.).
- `flyte.custom_context(...)` is used inside task code to set or override values for that scope. It does affect nested tasks invoked while that context is active. In practice this means you can override run-level entries, add new keys for downstream tasks, or both.
- Merging & precedence: contexts are merged; when the same key appears in multiple places the most recent/innermost value wins (i.e., values set by `flyte.custom_context(...)` override the run-level values from `with_runcontext(...)` for the duration of that block).
### Run-level context
Set base metadata once when starting the run:
```python
import flyte
env = flyte.TaskEnvironment("custom-context-example")
@env.task
async def leaf_task() -> str:
# Reads run-level context
print("leaf sees:", flyte.ctx().custom_context)
return flyte.ctx().custom_context.get("trace_id")
@env.task
async def root() -> str:
return await leaf_task()
if __name__ == "__main__":
flyte.init_from_config()
# Base context for the entire run
flyte.with_runcontext(custom_context={"trace_id": "root-abc", "experiment": "v1"}).run(root)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/custom-context/run_context.py*
Output (every task sees the base keys unless overridden):
```bash
leaf sees: {"trace_id": "root-abc", "experiment": "v1"}
```
### Overriding inside a task (local override that affects nested tasks)
Use `flyte.custom_context(...)` inside a task to override or add keys for downstream calls:
```python
@env.task
async def downstream() -> str:
print("downstream sees:", flyte.ctx().custom_context)
return flyte.ctx().custom_context.get("trace_id")
@env.task
async def parent() -> str:
print("parent initial:", flyte.ctx().custom_context)
# Override the trace_id for the nested call(s)
with flyte.custom_context(trace_id="child-override"):
val = await downstream() # downstream sees trace_id="child-override"
# After the context block, run-level values are back
print("parent after:", flyte.ctx().custom_context)
return val
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/custom-context/override_context.py*
If the run was started with `{"trace_id": "root-abc"}`, this prints:
```bash
parent initial: {"trace_id": "root-abc"}
downstream sees: {"trace_id": "child-override"}
parent after: {"trace_id": "root-abc"}
```
Note that the override affected the nested downstream task because it was invoked while the `flyte.custom_context` block was active.
### Adding new keys for nested tasks
You can add keys (not just override):
```python
with flyte.custom_context(experiment="exp-blue", run_group="g-7"):
await some_task() # some_task sees both base keys + the new keys
```
## Accessing custom context
Always via the Flyte runtime:
```python
ctx = flyte.ctx().custom_context
value = ctx.get("key")
```
You can access the custom context using either `flyte.ctx().custom_context` or the shorthand `flyte.get_custom_context()`, which returns the same dictionary of key/value pairs.
Values are always strings, so parse as needed:
```python
timeout = int(ctx["timeout_seconds"])
```
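For instance, numeric and boolean values can be converted explicitly. In this sketch a plain dict stands in for `flyte.ctx().custom_context`, and the keys are illustrative, not part of any Flyte API:

```python
# A plain dict standing in for flyte.ctx().custom_context; keys are illustrative.
ctx = {"timeout_seconds": "30", "debug": "true", "sample_rate": "0.25"}

timeout = int(ctx["timeout_seconds"])
debug = ctx.get("debug", "false").lower() in ("1", "true", "yes")
rate = float(ctx.get("sample_rate", "1.0"))

print(timeout, debug, rate)  # 30 True 0.25
```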
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/abort-tasks ===
# Abort and cancel actions
When running complex workflows, you may need to stop actions that are no longer needed.
This can happen when one branch of your workflow makes others redundant, when a task fails and its siblings should not continue, or when you need to manually intervene in a running workflow.
Flyte provides three mechanisms for stopping actions:
- **Automatic cleanup**: When a root action completes, all its in-progress descendant actions are automatically aborted.
- **Programmatic cancellation**: Cancel specific `asyncio` tasks from within your workflow code.
- **External abort**: Stop individual actions via the CLI, the UI, or the API.
For background on runs and actions, see [Runs and actions](../core-concepts/runs-and-actions).
## Action lifetime
The lifetime of all actions in a [run](../core-concepts/runs-and-actions) is tied to the lifetime of the root action (the first task that was invoked).
When the root action exits (whether it succeeds, fails, or returns early), all in-progress descendant actions are automatically aborted and no new actions can be enqueued.
This means you don't need to manually clean up child actions. Flyte handles it for you.
Consider this example where `main` exits after 10 seconds, but it has spawned a `sleep_for` action that is set to run for 30 seconds:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# main = "main"
# params = "seconds = 30"
# ///
import asyncio
import flyte
env = flyte.TaskEnvironment(name="action_lifetime")
@env.task
async def do_something():
print("Doing something")
await asyncio.sleep(5)
print("Finished doing something")
@env.task
async def sleep_for(seconds: int):
print(f"Sleeping for {seconds} seconds")
try:
await asyncio.sleep(seconds)
await do_something()
except asyncio.CancelledError:
print("sleep_for was cancelled")
return
print(f"Finished sleeping for {seconds} seconds")
@env.task
async def main(seconds: int):
print("Starting main")
asyncio.create_task(sleep_for(seconds))
await asyncio.sleep(10)
print("Main finished")
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main, seconds=30)
print(run.url)
run.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/abort-tasks/action_lifetime.py*
When `main` returns after 10 seconds, the `sleep_for` action (which still has 20 seconds remaining) is automatically aborted.
The `sleep_for` task receives an `asyncio.CancelledError`, giving it a chance to handle the cancellation gracefully.
## Canceling actions programmatically
As a workflow author, you can cancel specific in-progress actions by canceling their corresponding `asyncio` tasks.
This is useful in scenarios like hyperparameter optimization (HPO), where one action converges to the desired result and the remaining actions can be stopped to save compute.
To cancel actions programmatically:
1. Launch actions using `asyncio.create_task()` and retain references to the returned task objects.
2. When the desired condition is met, call `.cancel()` on the tasks you want to stop.
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# main = "main"
# params = "n = 30, f = 10.0"
# ///
import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment("cancel")
@env.task
async def sleepers(f: float, n: int):
await asyncio.sleep(f)
@env.task
async def failing_task(f: float):
raise ValueError("I will fail!")
@env.task
async def main(n: int, f: float):
sleeping_tasks = []
for i in range(n):
sleeping_tasks.append(asyncio.create_task(sleepers(f, i)))
await asyncio.sleep(f)
try:
await failing_task(f)
await asyncio.gather(*sleeping_tasks)
except flyte.errors.RuntimeUserError as e:
if e.code == "ValueError":
print(f"Received ValueError, canceling {len(sleeping_tasks)} sleeping tasks")
for t in sleeping_tasks:
t.cancel()
return
if __name__ == "__main__":
flyte.init_from_config()
print(flyte.run(main, 30, 10.0))
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/abort-tasks/cancel_tasks.py*
In this code:
* The `main` task launches 30 `sleepers` actions in parallel using `asyncio.create_task()`.
* It then calls `failing_task`, which raises a `ValueError`.
* The error is caught as a `flyte.errors.RuntimeUserError` (since user-raised exceptions are wrapped by Flyte).
* On catching the error, `main` cancels all sleeping tasks by calling `.cancel()` on each one, freeing their compute resources.
This pattern lets you react to runtime conditions and stop unnecessary work. For more on handling errors within workflows, see [Error handling](./error-handling).
## External abort
Sometimes you need to stop an action manually, outside the workflow code itself. You can abort individual actions using the CLI, the UI, or the API.
When an action is externally aborted, the parent action that awaits it receives a [`flyte.errors.ActionAbortedError`](../../api-reference/flyte-sdk/packages/flyte.errors/actionabortederror). You can catch this error to handle the abort gracefully.
### Aborting via the CLI
To abort a specific action:
```bash
flyte abort
```
Use `--project` and `--domain` to target a specific [project-domain pair](../projects-and-domains).
For all available options, see the [CLI reference](../../api-reference/flyte-cli#flyte-abort).
### Handling external aborts
When using `asyncio.gather()` with `return_exceptions=True`, externally aborted actions return an `ActionAbortedError` instead of raising it. This lets you inspect results and handle aborts on a per-action basis:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# main = "main"
# params = "n = 10, sleep_for = 30.0"
# ///
import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment("external_abort")
@env.task
async def long_sleeper(sleep_for: float):
await asyncio.sleep(sleep_for)
@env.task
async def main(n: int, sleep_for: float) -> str:
coros = [long_sleeper(sleep_for) for _ in range(n)]
results = await asyncio.gather(*coros, return_exceptions=True)
for i, r in enumerate(results):
if isinstance(r, flyte.errors.ActionAbortedError):
print(f"Action [{i}] was externally aborted")
return "Hello World!"
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main, 10, 30.0)
print(run.url)
run.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/abort-tasks/external_abort.py*
In this code:
* The `main` task launches 10 `long_sleeper` actions in parallel.
* If any action is externally aborted (via the CLI, the UI, or the API) while running, `asyncio.gather` captures the `ActionAbortedError` as a result instead of propagating it.
* The `main` task iterates over the results and logs which actions were aborted.
* Because the abort is handled, `main` can continue executing and return its result normally.
Without `return_exceptions=True`, an external abort would raise `ActionAbortedError` directly, which you can handle with a standard `try...except` block.
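The difference between the two `gather` modes can be seen with plain `asyncio`, using a stand-in exception in place of `ActionAbortedError` (this is a pure-Python sketch, not Flyte code):

```python
import asyncio


class StandInAbort(Exception):
    """Stands in for flyte.errors.ActionAbortedError in this sketch."""


async def worker(i: int) -> int:
    if i == 2:
        raise StandInAbort(f"action {i} aborted")
    return i


async def main() -> tuple[list[int], str]:
    # With return_exceptions=True, the error is delivered as a result value:
    results = await asyncio.gather(*(worker(i) for i in range(4)), return_exceptions=True)
    aborted = [i for i, r in enumerate(results) if isinstance(r, StandInAbort)]
    # Without it, the first error propagates and must be caught:
    try:
        await asyncio.gather(*(worker(i) for i in range(4)))
    except StandInAbort as e:
        caught = str(e)
    return aborted, caught


aborted, caught = asyncio.run(main())
print(aborted, caught)  # [2] action 2 aborted
```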
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/container-tasks ===
# Container tasks
Container tasks are one of Flyte's superpowers. They allow you to execute tasks using any container image without requiring the Flyte SDK to be installed in that container. This means you can run code written in any language, execute shell scripts, or even use pre-built containers pulled directly from the internet while still maintaining Flyte's data orchestration capabilities.
## What are Container Tasks?
A container task is a special type of Flyte task that executes arbitrary container images. Unlike standard `@task` decorated functions that require the Flyte SDK, container tasks can run:
- Code written in any programming language (Rust, Go, Java, R, etc.)
- Legacy containers with unsupported Python versions
- Pre-built bioinformatics or scientific computing containers
- Shell scripts and command-line tools
- Dynamically generated code in sandboxed environments
## How Data Flows In and Out
The magic of container tasks lies in Flyte's **copilot sidecar system**. When you execute a container task, Flyte:
1. Launches your specified container alongside a copilot sidecar container
2. Uses shared Kubernetes pod volumes to pass data between containers
3. Reads inputs from `input_data_dir` and writes outputs to `output_data_dir`
4. Automatically handles serialization and deserialization of typed data
This means you can construct workflows where some tasks are container tasks while others are Python functions, and data will flow seamlessly between them.
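As a rough illustration of this file-based contract (a toy stand-in, not the real copilot sidecar), inputs are materialized as files, the command runs against them, and outputs are read back by path:

```python
import pathlib
import subprocess
import tempfile

# Toy model of the copilot data contract: write inputs as files, run the
# user command, read outputs back. Directory names mirror the defaults.
with tempfile.TemporaryDirectory() as tmp:
    inputs = pathlib.Path(tmp) / "inputs"
    outputs = pathlib.Path(tmp) / "outputs"
    inputs.mkdir()
    outputs.mkdir()

    (inputs / "name").write_text("flyte")  # each typed input becomes a file

    # The container's command reads inputs and writes outputs by path:
    subprocess.run(
        ["/bin/sh", "-c", f'echo "Hello, $(cat {inputs}/name)." > {outputs}/greeting'],
        check=True,
    )

    greeting = (outputs / "greeting").read_text().strip()  # outputs read back
    print(greeting)  # Hello, flyte.
```

The real system adds typed serialization on both sides of this handoff, but the shape of the contract is the same.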
## Basic Usage
Here's a simple example that runs a shell command in an Alpine container:
```python
import flyte
from flyte.extras import ContainerTask
greeting_task = ContainerTask(
name="echo_and_return_greeting",
image=flyte.Image.from_base("alpine:3.18"),
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"name": str},
outputs={"greeting": str},
command=[
"/bin/sh",
"-c",
"echo 'Hello, my name is {{.inputs.name}}.' | tee -a /var/outputs/greeting"
],
)
```
### Template Syntax for Inputs
Container tasks support template-style references to inputs using the syntax `{{.inputs.<input_name>}}`. This gets replaced with the actual input value at runtime:
```python
command=["/bin/sh", "-c", "echo 'Processing {{.inputs.user_id}}' > /var/outputs/result"]
```
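Conceptually, the substitution is a straightforward token replacement. A minimal sketch (not Flyte's implementation) of that rendering step:

```python
import re


def render(command: str, inputs: dict) -> str:
    """Replace each {{.inputs.<name>}} token with the matching input value."""
    return re.sub(r"\{\{\.inputs\.(\w+)\}\}", lambda m: str(inputs[m.group(1)]), command)


print(render("echo 'Processing {{.inputs.user_id}}' > /var/outputs/result", {"user_id": 42}))
# echo 'Processing 42' > /var/outputs/result
```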
### Using Container Tasks in Workflows
Container tasks integrate seamlessly with Python tasks:
```python
container_env = flyte.TaskEnvironment.from_task("container_env", greeting_task)
env = flyte.TaskEnvironment(name="hello_world", depends_on=[container_env])
@env.task
async def say_hello(name: str = "flyte") -> str:
print("Hello container task")
return await greeting_task(name=name)
```
## Advanced: Passing Files and Directories
Container tasks can accept `File` and `Dir` inputs. For these types, use path-based syntax (not template syntax) in your commands:
```python
from flyte.io import File
import pathlib
code_runner = ContainerTask(
name="python_code_runner",
image="ghcr.io/astral-sh/uv:debian-slim",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"script.py": File, "a": int, "b": int},
outputs={"result": int},
command=[
"/bin/sh",
"-c",
"uv run /var/inputs/script.py {{.inputs.a}} {{.inputs.b}} > /var/outputs/result"
],
)
@env.task
async def execute_script() -> int:
path = pathlib.Path(__file__).parent / "my_script.py"
script_file = await File.from_local(path)
return await code_runner(**{"script.py": script_file, "a": 10, "b": 20})
```
Note that when passing files, the input key can include the filename (e.g., `"script.py"`), and you reference it in the command as `/var/inputs/script.py`.
## Use Case: Agentic Sandbox Execution
Container tasks are perfect for running AI-generated code in isolated environments. You can generate a data analysis script dynamically and execute it safely:
```python
import flyte
from flyte.extras import ContainerTask
from flyte.io import File
import pathlib
env = flyte.TaskEnvironment(name="agentic_sandbox")
@env.task
async def run_generated_code(script_content: str, param_a: int, param_b: int) -> int:
# Define a container task that runs arbitrary Python code
sandbox = ContainerTask(
name="code_sandbox",
image="ghcr.io/astral-sh/uv:debian-slim",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"script": File, "a": int, "b": int},
outputs={"result": int},
command=[
"/bin/sh",
"-c",
"uv run --script /var/inputs/script {{.inputs.a}} {{.inputs.b}} > /var/outputs/result"
],
)
# Save the generated script to a temporary file
temp_path = pathlib.Path("/tmp/generated_script.py")
temp_path.write_text(script_content)
# Execute it in the sandbox
script_file = await File.from_local(temp_path)
return await sandbox(script=script_file, a=param_a, b=param_b)
```
This pattern allows you to:
- Generate code using LLMs or other AI systems
- Execute it in a controlled, isolated environment
- Capture results and integrate them back into your workflow
- Maintain full observability and reproducibility
## Use Case: Legacy and Specialized Containers
Many scientific and bioinformatics tools are distributed as pre-built containers. Container tasks let you integrate them directly:
```python
# Run a bioinformatics tool
blast_task = ContainerTask(
name="run_blast",
image="ncbi/blast:latest",
input_data_dir="/data",
output_data_dir="/results",
inputs={"query": File, "database": str},
outputs={"alignments": File},
command=[
"blastn",
"-query", "/data/query",
"-db", "{{.inputs.database}}",
"-out", "/results/alignments",
"-outfmt", "6"
],
)
# Run legacy code with an old Python version
legacy_task = ContainerTask(
name="legacy_python",
image="python:2.7", # Unsupported Python version
input_data_dir="/app/inputs",
output_data_dir="/app/outputs",
inputs={"data_file": File},
outputs={"processed": File},
command=[
"python",
"/legacy_app/process.py",
"/app/inputs/data_file",
"/app/outputs/processed"
],
)
```
## Use Case: Multi-Language Workflows
Build workflows that span multiple languages:
```python
# Rust task for high-performance computation
rust_task = ContainerTask(
name="rust_compute",
image="rust:1.75",
inputs={"n": int},
outputs={"result": int},
input_data_dir="/inputs",
output_data_dir="/outputs",
command=["./compute_binary", "{{.inputs.n}}"],
)
# Python task for orchestration
@env.task
async def multi_lang_workflow(iterations: int) -> dict:
# Call Rust task for heavy computation
computed = await rust_task(n=iterations)
# Process results in Python
processed = await python_analysis_task(computed)
return {"rust_result": computed, "analysis": processed}
```
## Configuration Options
### ContainerTask Parameters
- **name**: Unique identifier for the task
- **image**: Container image to use (string or `Image` object)
- **command**: Command to execute in the container (list of strings)
- **inputs**: Dictionary mapping input names to types
- **outputs**: Dictionary mapping output names to types
- **input_data_dir**: Directory where Flyte writes input data (default: `/var/inputs`)
- **output_data_dir**: Directory where Flyte reads output data (default: `/var/outputs`)
- **arguments**: Additional command arguments (list of strings)
- **metadata_format**: Format for metadata serialization (`"JSON"`, `"YAML"`, or `"PROTO"`)
- **local_logs**: Whether to print container logs during local execution (default: `True`)
### Supported Input/Output Types
Container tasks support all standard Flyte types:
- Primitives: `str`, `int`, `float`, `bool`
- Temporal: `datetime.datetime`, `datetime.timedelta`
- File system: `File`, `Dir`
- Complex types: dataclasses, Pydantic models (serialized as JSON/YAML/PROTO)
## Best Practices
1. **Use specific image tags**: Prefer `alpine:3.18` over `alpine:latest` for reproducibility
2. **Keep containers focused**: Each container task should do one thing well
3. **Handle errors gracefully**: Ensure your container commands exit with appropriate status codes
4. **Test locally first**: Container tasks can run locally with Docker, making debugging easier
5. **Consider image size**: Smaller images lead to faster task startup times
6. **Document input/output contracts**: Clearly specify what data flows in and out
## Local Execution
Container tasks require Docker to be installed and running on your local machine. When you run them locally, Flyte will:
1. Pull the specified image (if not already available)
2. Mount local directories for inputs and outputs
3. Stream container logs to your console
4. Extract outputs after container completion
This makes it easy to develop and test container tasks before deploying to a remote cluster.
## When to Use Container Tasks
Choose container tasks when you need to:
- Run code in languages other than Python
- Execute pre-built tools or legacy applications
- Isolate potentially unsafe code (AI-generated scripts)
- Use specific runtime environments or dependencies
- Integrate external tools without Python wrappers
- Execute shell scripts or command-line utilities
For Python code that can use the Flyte SDK, standard `@task` decorated functions are usually simpler and more efficient.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/links ===
# Links
Links let you add clickable URLs to tasks that appear in the Flyte UI. Use them to connect tasks to external tools like experiment trackers, monitoring dashboards or custom internal services.
You can attach links to tasks in two ways:
- **Statically** in the task decorator with `links=`
- **Dynamically** at call time with `task.override(links=...)`
`Link` is a Python [Protocol](https://docs.python.org/3/library/typing.html#typing.Protocol) that you subclass to define how URLs are generated. The Weights & Biases plugin provides a [built-in link implementation](../../api-reference/integrations/wandb/packages/flyteplugins.wandb/wandb) as an example.
## Creating a link
To create a link, subclass `Link` as a dataclass and implement the `get_link()` method. The method returns the URL string to display in the UI:
```python
from dataclasses import dataclass
import flyte
from flyte import Link
@dataclass
class GrafanaLink(Link):
dashboard_url: str
name: str = "Grafana"
def get_link(
self,
run_name: str,
project: str,
domain: str,
context: dict,
parent_action_name: str,
action_name: str,
pod_name: str,
**kwargs,
) -> str:
return f"{self.dashboard_url}?var-pod={pod_name}"
env = flyte.TaskEnvironment(...)
@env.task(links=(GrafanaLink(dashboard_url="https://grafana.example.com/d/abc123"),))
def my_task() -> str:
return "done"
```
The link appears as a clickable "Grafana" link in the Flyte UI for every execution of `my_task`.
## Using execution metadata
The `get_link()` method receives execution metadata that you can use to construct dynamic URLs. Here's an example modeled on the [built-in Wandb](../../integrations/wandb/_index) link that uses the `context` dict to resolve a run ID:
```python
from dataclasses import dataclass
from typing import Optional

from flyte import Link

@dataclass
class Wandb(Link):
    project: str
    entity: str
    id: Optional[str] = None
    name: str = "Weights & Biases"

    def get_link(
        self,
        run_name: str,
        project: str,
        domain: str,
        context: dict[str, str],
        parent_action_name: str,
        action_name: str,
        pod_name: str,
        **kwargs,
    ) -> str:
        run_id = self.id or context.get("wandb_id", run_name)
        return f"https://wandb.ai/{self.entity}/{self.project}/runs/{run_id}"
```
The `name` attribute controls the display label in the UI.
See the [`get_link()` API reference](../../api-reference/flyte-sdk/packages/flyte/link#get_link) for more details. Note that `action_name` and `pod_name` are template variables (`{{.actionName}}` and `{{.podName}}`) that are populated by the backend at runtime.
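The fallback order in the example above (explicit `id`, then the `context` entry, then the run name) can be exercised in a standalone sketch that needs no Flyte backend. `WandbStyleLink` here is an illustrative stand-in, not the real plugin class:

```python
from dataclasses import dataclass
from typing import Optional

# Standalone illustration of the run-ID resolution order: an explicit `id`
# wins, then the `context` entry, then the run name as a final fallback.
@dataclass
class WandbStyleLink:
    project: str
    entity: str
    id: Optional[str] = None

    def get_link(self, run_name: str, context: dict, **kwargs) -> str:
        run_id = self.id or context.get("wandb_id", run_name)
        return f"https://wandb.ai/{self.entity}/{self.project}/runs/{run_id}"

link = WandbStyleLink(project="proj", entity="team")
print(link.get_link("run-7", {"wandb_id": "abc123"}))  # context entry wins
print(link.get_link("run-7", {}))                      # falls back to run_name
```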
## Dynamic links with override
Use `task.override(links=...)` to set links at runtime. This is useful when link parameters depend on runtime values like run IDs or configuration:
```python
import flyte
from flyteplugins.wandb import Wandb

env = flyte.TaskEnvironment(...)

WANDB_PROJECT = "my-ml-project"
WANDB_ENTITY = "my-team"

@env.task
def train_model(config: dict) -> dict:
    # Training logic here
    return {"accuracy": 0.95}

@env.task
async def main(wandb_id: str) -> dict:
    result = train_model.override(
        links=(
            Wandb(
                project=WANDB_PROJECT,
                entity=WANDB_ENTITY,
                id=wandb_id,
            ),
        )
    )(config={"lr": 0.001})
    return result

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main, wandb_id="my-run-id")
```
The `override` approach lets you attach links with values that are only known at runtime, such as dynamically generated run IDs.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/reports ===
# Reports
The reports feature allows you to display and update custom output in the UI during task execution.
First, you set the `report=True` flag in the task decorator. This enables the reporting feature for that task.
Within a task with reporting enabled, a [`flyte.report.Report`](../../api-reference/flyte-sdk/packages/flyte.report/report) object is created automatically.
A `Report` object contains one or more tabs, each of which contains HTML.
You can write HTML to an existing tab and create new tabs to organize your content.
Initially, the `Report` object has one tab (the default tab) with no content.
To write content:
- [`flyte.report.log()`](../../api-reference/flyte-sdk/packages/flyte.report/_index#log) appends HTML content directly to the default tab.
- [`flyte.report.replace()`](../../api-reference/flyte-sdk/packages/flyte.report/_index#replace) replaces the content of the default tab with new HTML.
To get or create a new tab:
- [`flyte.report.get_tab()`](../../api-reference/flyte-sdk/packages/flyte.report/_index#get_tab) takes a unique name for the tab and returns the existing tab with that name, creating it if it doesn't yet exist.
It returns a `flyte.report._report.Tab` object.
You can `log()` or `replace()` HTML on the `Tab` object just as you can directly on the `Report` object.
Finally, you send the report to the Flyte server and make it visible in the UI:
- [`flyte.report.flush()`](../../api-reference/flyte-sdk/packages/flyte.report/_index#flush) dispatches the report.
**It is important to call this method to ensure that the data is sent**.
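To make these semantics concrete, here is a plain-Python analogy of the API described above. It is only a model of the behavior, not the real `flyte.report` implementation:

```python
# Plain-Python analogy of the Report/Tab semantics -- log() appends,
# replace() overwrites, and get_tab() is get-or-create.
class Tab:
    def __init__(self, name: str):
        self.name = name
        self.html = ""

    def log(self, html: str):
        # Append HTML to this tab
        self.html += html

    def replace(self, html: str):
        # Overwrite this tab's content
        self.html = html

class Report:
    def __init__(self):
        # A report starts with one empty default tab
        self._tabs = {"default": Tab("default")}

    def log(self, html: str):
        self._tabs["default"].log(html)

    def get_tab(self, name: str) -> Tab:
        # Return the existing tab if present, otherwise create it
        return self._tabs.setdefault(name, Tab(name))

report = Report()
report.log("<h1>Overview</h1>")      # appends to the default tab
tab2 = report.get_tab("Tab 2")       # creates "Tab 2"
tab2.log("<p>More detail</p>")
same = report.get_tab("Tab 2")       # returns the same Tab object
print(same is tab2)  # → True
```

In a real task you would make the equivalent calls on `flyte.report` and finish with `flush()` to send the content to the server.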
## A simple example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
# ]
# main = "main"
# params = ""
# ///

import flyte
import flyte.report

env = flyte.TaskEnvironment(name="reports_example")

@env.task(report=True)
async def task1():
    # Write HTML to the default tab
    await flyte.report.log.aio("<h2>Hello from the default tab</h2>")
    # Get or create a second tab and write HTML to it
    tab2 = flyte.report.get_tab("Tab 2")
    await tab2.log.aio("<h2>Hello from Tab 2</h2>")
    # Send the report to the server
    await flyte.report.flush.aio()

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(task1)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/simple.py*
Here we define a task `task1` that logs some HTML content to the default tab and creates a new tab named "Tab 2" where it logs additional HTML content.
The `flush` method is called to send the report to the backend.
## A more complex example
Here is another example.
We import the necessary modules, set up the task environment, define the main task with reporting enabled and define the data generation function:
```python
import json
import random

import flyte
import flyte.report

env = flyte.TaskEnvironment(
    name="globe_visualization",
)

@env.task(report=True)
async def generate_globe_visualization():
    await flyte.report.replace.aio(get_html_content())
    await flyte.report.flush.aio()

def generate_globe_data():
    """Generate sample data points for the globe"""
    cities = [
        {"city": "New York", "country": "USA", "lat": 40.7128, "lng": -74.0060},
        {"city": "London", "country": "UK", "lat": 51.5074, "lng": -0.1278},
        {"city": "Tokyo", "country": "Japan", "lat": 35.6762, "lng": 139.6503},
        {"city": "Sydney", "country": "Australia", "lat": -33.8688, "lng": 151.2093},
        {"city": "Paris", "country": "France", "lat": 48.8566, "lng": 2.3522},
        {"city": "São Paulo", "country": "Brazil", "lat": -23.5505, "lng": -46.6333},
        {"city": "Mumbai", "country": "India", "lat": 19.0760, "lng": 72.8777},
        {"city": "Cairo", "country": "Egypt", "lat": 30.0444, "lng": 31.2357},
        {"city": "Moscow", "country": "Russia", "lat": 55.7558, "lng": 37.6176},
        {"city": "Beijing", "country": "China", "lat": 39.9042, "lng": 116.4074},
        {"city": "Lagos", "country": "Nigeria", "lat": 6.5244, "lng": 3.3792},
        {"city": "Mexico City", "country": "Mexico", "lat": 19.4326, "lng": -99.1332},
        {"city": "Bangkok", "country": "Thailand", "lat": 13.7563, "lng": 100.5018},
        {"city": "Istanbul", "country": "Turkey", "lat": 41.0082, "lng": 28.9784},
        {"city": "Buenos Aires", "country": "Argentina", "lat": -34.6118, "lng": -58.3960},
        {"city": "Cape Town", "country": "South Africa", "lat": -33.9249, "lng": 18.4241},
        {"city": "Dubai", "country": "UAE", "lat": 25.2048, "lng": 55.2708},
        {"city": "Singapore", "country": "Singapore", "lat": 1.3521, "lng": 103.8198},
        {"city": "Stockholm", "country": "Sweden", "lat": 59.3293, "lng": 18.0686},
        {"city": "Vancouver", "country": "Canada", "lat": 49.2827, "lng": -123.1207},
    ]
    categories = ["high", "medium", "low", "special"]
    data_points = []
    for city in cities:
        data_point = {**city, "value": random.randint(10, 100), "category": random.choice(categories)}
        data_points.append(data_point)
    return data_points
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py*
We then define the HTML content for the report:
```python
def get_html_content():
    data_points = generate_globe_data()
    html_content = f"""
    ...
    """
    return html_content
```
(We exclude it here due to length. You can find it in the [source file](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py)).
Finally, we run the workflow:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(generate_globe_visualization)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py*
When the workflow runs, the report will be visible in the UI:

## Streaming example
Above we demonstrated reports that are sent to the UI once, at the end of the task execution.
But, you can also stream updates to the report during task execution and see the display update in real-time.
You do this by calling `flyte.report.flush()` (or specifying `do_flush=True` in `flyte.report.log()`) periodically during task execution, rather than only once at the end.
> [!NOTE]
> In the above examples we explicitly call `flyte.report.flush()` to send the report to the UI.
> In fact, this is optional since flush will be called automatically at the end of the task execution.
> For streaming reports, on the other hand, calling `flush()` periodically (or specifying `do_flush=True`
> in `flyte.report.log()`) is necessary to display the updates.
First we import the necessary modules, and set up the task environment:
```python
import asyncio
import json
import math
import random
import time
from datetime import datetime
from typing import List
import flyte
import flyte.report
env = flyte.TaskEnvironment(name="streaming_reports")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py*
Next we define the HTML content for the report:
```python
DATA_PROCESSING_DASHBOARD_HTML = """
...
"""
```
(We exclude it here due to length. You can find it in the [source file](
https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py)).
Finally, we define the task that renders the report (`data_processing_dashboard`), the driver task of the workflow (`main`), and the run logic:
```python
@env.task(report=True)
async def data_processing_dashboard(total_records: int = 50000) -> str:
    """
    Simulates a data processing pipeline with real-time progress visualization.
    Updates every second for approximately 1 minute.
    """
    await flyte.report.log.aio(DATA_PROCESSING_DASHBOARD_HTML, do_flush=True)

    # Simulate data processing
    processed = 0
    errors = 0
    batch_sizes = [800, 850, 900, 950, 1000, 1050, 1100]  # Variable processing rates
    start_time = time.time()

    while processed < total_records:
        # Simulate variable processing speed
        batch_size = random.choice(batch_sizes)

        # Add some processing delays occasionally
        if random.random() < 0.1:  # 10% chance of slower batch
            batch_size = int(batch_size * 0.6)
            await flyte.report.log.aio("""
            """, do_flush=True)  # (HTML snippet omitted; see source file)
        elif random.random() < 0.05:  # 5% chance of error
            errors += random.randint(1, 5)
            await flyte.report.log.aio("""
            """, do_flush=True)  # (HTML snippet omitted; see source file)
        else:
            await flyte.report.log.aio(f"""
            """, do_flush=True)  # (HTML snippet omitted; see source file)

        processed = min(processed + batch_size, total_records)
        current_time = time.time()
        elapsed = current_time - start_time
        rate = int(batch_size) if elapsed < 1 else int(processed / elapsed)
        success_rate = ((processed - errors) / processed) * 100 if processed > 0 else 100

        # Update dashboard
        await flyte.report.log.aio(f"""
        """, do_flush=True)  # (HTML snippet omitted; see source file)

        print(f"Processed {processed:,} records, Errors: {errors}, Rate: {rate:,}"
              f" records/sec, Success Rate: {success_rate:.2f}%", flush=True)

        await asyncio.sleep(1)  # Update every second
        if processed >= total_records:
            break

    # Final completion message
    total_time = time.time() - start_time
    avg_rate = int(total_records / total_time)
    await flyte.report.log.aio(f"""
    🎉 Processing Complete!
    Total Records: {total_records:,}
    Processing Time: {total_time:.1f} seconds
    Average Rate: {avg_rate:,} records/second
    Success Rate: {success_rate:.2f}%
    Errors Handled: {errors}
    """, do_flush=True)
    print(f"Data processing completed: {processed:,} records processed with {errors} errors.", flush=True)
    return f"Processed {total_records:,} records successfully"

@env.task
async def main():
    """
    Main task to run the dashboard report.
    """
    await data_processing_dashboard(total_records=50000)

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py*
The key to the live-update capability is the `while` loop that appends JavaScript to the report. Each appended script executes as soon as it is added to the document and updates the dashboard in place.
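To make that mechanism concrete, the kind of fragment appended on each iteration can be sketched with a small helper (hypothetical, not taken from the example source) that emits a `<script>` tag updating an existing element in place:

```python
# Hypothetical helper: build a fragment that, when appended to the report,
# runs immediately and rewrites the text of an existing progress element.
def progress_fragment(processed: int, total: int, element_id: str = "progress") -> str:
    pct = 100 * processed / total
    return (
        f"<script>"
        f"document.getElementById('{element_id}').textContent = "
        f"'{processed:,} / {total:,} ({pct:.1f}%)';"
        f"</script>"
    )

print(progress_fragment(25000, 50000))
```

In a streaming task you would pass each fragment to `flyte.report.log()` with `do_flush=True` so the UI updates as the loop runs.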
When the workflow runs, you can see the report updating in real-time in the UI:

=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/notebooks ===
# Notebooks
Flyte is designed to work seamlessly with Jupyter notebooks, allowing you to write and execute workflows directly within a notebook environment.
## Iterating on and running a workflow
Download the following notebook file and open it in your favorite Jupyter environment: [interactive.ipynb](../../_static/public/interactive.ipynb)
In this example we have a simple workflow defined in our notebook.
You can iterate on the code in the notebook while running each cell in turn.
Note that the [`flyte.init()`](../../api-reference/flyte-sdk/packages/flyte/_index#init) call at the top of the notebook looks like this:
```python
flyte.init(
    endpoint="https://union.example.com",
    org="example_org",
    project="example_project",
    domain="development",
)
```
You will have to adjust it to match your Union server endpoint, organization, project, and domain.
## Accessing runs and downloading logs
Similarly, you can download the following notebook file and open it in your favorite Jupyter environment: [remote.ipynb](../../_static/public/remote.ipynb)
In this example we use the `flyte.remote` package to list existing runs, access them, and download their details and logs.
For a comprehensive guide on working with runs, actions, inputs, and outputs, see [Interact with runs and actions](../task-deployment/interacting-with-runs).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/remote-tasks ===
# Remote tasks
Remote tasks let you use previously deployed tasks without importing their code or dependencies. This enables teams to share and reuse tasks without managing complex dependency chains or container images.
## Prerequisites
Remote tasks must be deployed before you can use them. See the [task deployment guide](../task-deployment/_index) for details.
## Basic usage
Use `flyte.remote.Task.get()` to reference a deployed task:
```python
import flyte
import flyte.remote

env = flyte.TaskEnvironment(name="my_env")

# Get the latest version of a deployed task
data_processor = flyte.remote.Task.get(
    "data_team.spark_analyzer",
    auto_version="latest"
)

# Use it in your task
@env.task
async def my_task(data_path: str) -> flyte.io.DataFrame:
    # Call the reference task like any other task
    result = await data_processor(input_path=data_path)
    return result
```
You can run this directly without deploying it:
```bash
flyte run my_workflow.py my_task --data_path s3://my-bucket/data.parquet
```
## Understanding lazy loading
Remote tasks use **lazy loading** to keep module imports fast and enable flexible client configuration. When you call `flyte.remote.Task.get()`, it returns a lazy reference that doesn't actually fetch the task from the server until the first invocation.
### When tasks are fetched
The remote task is fetched from the server only when:
- You call `flyte.run()` with the task
- You call `flyte.deploy()` with code that uses the task
- You invoke the task with the `()` operator inside another task
- You explicitly call `.fetch()` on the lazy reference
```python
import flyte.remote

# This does NOT make a network call - returns a lazy reference
data_processor = flyte.remote.Task.get(
    "data_team.spark_analyzer",
    auto_version="latest"
)

# The task is fetched here when you invoke it
run = flyte.run(data_processor, input_path="s3://my-bucket/data.parquet")
```
### Benefits of lazy loading
**Fast module loading**: Since no network calls are made during import, your Python modules load quickly even when referencing many remote tasks.
**Late binding**: You can call `flyte.init()` after importing remote tasks, and the correct client will be bound when the task is actually invoked:
```python
import flyte
import flyte.remote

# Load remote task reference at module level
data_processor = flyte.remote.Task.get(
    "data_team.spark_analyzer",
    auto_version="latest"
)

# Initialize the client later
flyte.init_from_config()

# The task uses the client configured above
run = flyte.run(data_processor, input_path="s3://data.parquet")
```
### Error handling
Because of lazy loading, if a referenced task doesn't exist, you won't get an error when calling `get()`. Instead, the error occurs during invocation, raising a `flyte.errors.RemoteTaskNotFoundError`:
```python
import flyte
import flyte.errors
import flyte.remote

# This succeeds even if the task doesn't exist
data_processor = flyte.remote.Task.get(
    "nonexistent.task",
    auto_version="latest"
)

try:
    # Error occurs here during invocation
    run = flyte.run(data_processor, input_path="s3://data.parquet")
except flyte.errors.RemoteTaskNotFoundError as e:
    print(f"Task not found or invocation failed: {e}")
    # Handle the error - perhaps use a fallback task
    # or notify the user that the task needs to be deployed
```
You can also catch errors when using remote tasks within other tasks:
```python
import flyte.errors

@env.task
async def pipeline_with_fallback(data_path: str) -> dict:
    try:
        # Try to use the remote task
        result = await data_processor(input_path=data_path)
        return {"status": "success", "result": result}
    except flyte.errors.RemoteTaskNotFoundError as e:
        # Fall back to local processing (local_process is assumed defined elsewhere)
        print(f"Remote task failed: {e}, using local fallback")
        return {"status": "fallback", "result": local_process(data_path)}
    except flyte.errors.RemoteTaskUsageError as e:
        raise ValueError(f"Invalid usage of remote task (arguments may not match): {e}")
```
### Eager fetching with `fetch()`
While lazy loading is convenient, you can explicitly fetch a task upfront using the `fetch()` method. This is useful for:
- **Catching errors early**: Validate that the task exists before execution starts
- **Caching**: Avoid the network call on first invocation when running multiple times
- **Service initialization**: Pre-load tasks when your service starts
```python
import flyte
import flyte.errors
import flyte.remote

# Get the lazy reference
data_processor = flyte.remote.Task.get(
    "data_team.spark_analyzer",
    auto_version="latest"
)

try:
    # Eagerly fetch the task details
    task_details = data_processor.fetch()
    # Now the task is cached - subsequent calls won't hit the remote service
    # You can pass either the original reference or task_details to flyte.run
    run1 = flyte.run(data_processor, input_path="s3://data1.parquet")
    run2 = flyte.run(task_details, input_path="s3://data2.parquet")
except flyte.errors.RemoteTaskNotFoundError as e:
    print(f"Task not found at startup: {e}")
    raise
except flyte.errors.RemoteTaskUsageError as e:
    print(f"Task run validation failed: {e}")
    # Handle the error before any execution attempts
```
For async contexts, use `await fetch.aio()`:
```python
import flyte
import flyte.errors
import flyte.remote

async def initialize_service():
    processor_ref = flyte.remote.Task.get(
        "data_team.spark_analyzer",
        auto_version="latest"
    )
    try:
        # Fetch asynchronously
        task_details = await processor_ref.fetch.aio()
        print(f"Task {task_details.name} loaded successfully")
        return processor_ref  # Return the cached reference
    except flyte.errors.RemoteTaskNotFoundError as e:
        print(f"Failed to load task: {e}")
        raise

# Initialize once at service startup
cached_processor = None

async def startup():
    global cached_processor
    cached_processor = await initialize_service()

# Later in your service
async def process_request(data_path: str):
    # The task is already cached from initialization
    # No network call on first invocation
    run = flyte.run(cached_processor, input_path=data_path)
    return run
```
**When to use eager fetching**:
- **Service startup**: Fetch all remote tasks during initialization to validate they exist and cache them
- **Multiple invocations**: If you'll invoke the same task many times, fetch once to cache it
- **Fail-fast validation**: Catch configuration errors before execution begins
**When lazy loading is better**:
- **Single-use tasks**: If you only invoke the task once, lazy loading is simpler
- **Import-time overhead**: Keep imports fast by deferring network calls
- **Conditional usage**: If the task may not be needed, don't fetch it upfront
### Module-level vs dynamic loading
**Module-level loading (recommended)**: Load remote tasks at the module level for cleaner, more maintainable code:
```python
import flyte.remote

# Module-level - clear and maintainable
data_processor = flyte.remote.Task.get(
    "data_team.spark_analyzer",
    auto_version="latest"
)

@env.task
async def my_task(data_path: str):
    return await data_processor(input_path=data_path)
```
**Dynamic loading**: You can also load remote tasks dynamically within a task if needed:
```python
@env.task
async def dynamic_pipeline(task_name: str, data_path: str):
    # Load the task based on runtime parameters
    processor = flyte.remote.Task.get(
        f"data_team.{task_name}",
        auto_version="latest"
    )
    try:
        result = await processor(input_path=data_path)
        return result
    except flyte.errors.RemoteTaskNotFoundError as e:
        raise ValueError(f"Task {task_name} not found: {e}")
```
## Complete example
This example shows how different teams can collaborate using remote tasks.
### Team A: Spark environment
Team A maintains Spark-based data processing tasks:
```python
# spark_env.py
from dataclasses import dataclass

import flyte

env = flyte.TaskEnvironment(name="spark_env")

@dataclass
class AnalysisResult:
    mean_value: float
    std_dev: float

@env.task
async def analyze_data(data_path: str) -> AnalysisResult:
    # Spark code here (not shown)
    return AnalysisResult(mean_value=42.5, std_dev=3.2)

@env.task
async def compute_score(result: AnalysisResult) -> float:
    # More Spark processing
    return result.mean_value / result.std_dev
```
Deploy the Spark environment:
```bash
flyte deploy spark_env/
```
### Team B: ML environment
Team B maintains PyTorch-based ML tasks:
```python
# ml_env.py
from pydantic import BaseModel

import flyte

env = flyte.TaskEnvironment(name="ml_env")

class PredictionRequest(BaseModel):
    feature_x: float
    feature_y: float

class Prediction(BaseModel):
    score: float
    confidence: float
    model_version: str

@env.task
async def run_inference(request: PredictionRequest) -> Prediction:
    # PyTorch model inference (not shown)
    return Prediction(
        score=request.feature_x * 2.5,
        confidence=0.95,
        model_version="v2.1"
    )
```
Deploy the ML environment:
```bash
flyte deploy ml_env/
```
### Team C: Orchestration
Team C builds a workflow using remote tasks from both teams without needing Spark or PyTorch dependencies:
```python
# orchestration_env.py
import flyte
import flyte.remote

env = flyte.TaskEnvironment(name="orchestration")

# Reference tasks from other teams
analyze_data = flyte.remote.Task.get(
    "spark_env.analyze_data",
    auto_version="latest"
)
compute_score = flyte.remote.Task.get(
    "spark_env.compute_score",
    auto_version="latest"
)
run_inference = flyte.remote.Task.get(
    "ml_env.run_inference",
    auto_version="latest"
)

@env.task
async def orchestrate_pipeline(data_path: str) -> float:
    # Use Spark tasks without Spark dependencies
    analysis = await analyze_data(data_path=data_path)
    # Access attributes from the result
    # (Flyte creates a fake type that allows attribute access)
    print(f"Analysis: mean={analysis.mean_value}, std={analysis.std_dev}")
    data_score = await compute_score(result=analysis)

    # Use ML task without PyTorch dependencies
    # Pass Pydantic models as dictionaries
    prediction = await run_inference(
        request={
            "feature_x": analysis.mean_value,
            "feature_y": data_score
        }
    )
    # Access Pydantic model attributes
    print(f"Prediction: {prediction.score} (confidence: {prediction.confidence})")
    return prediction.score
```
Run the orchestration task directly (no deployment needed):
**Using Python API**:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(
        orchestrate_pipeline,
        data_path="s3://my-bucket/data.parquet"
    )
    print(f"Execution URL: {run.url}")
    # You can wait for the execution
    run.wait()
    # You can then retrieve the outputs
    print(f"Pipeline result: {run.outputs()}")
```
**Using CLI**:
```bash
flyte run orchestration_env.py orchestrate_pipeline --data_path s3://my-bucket/data.parquet
```
## Invoking remote tasks in a script
You can also run any remote task directly from a script in a similar way:
```python
import flyte
import flyte.models
import flyte.remote

flyte.init_from_config()

# Fetch the task
remote_task = flyte.remote.Task.get("package-example.calculate_average", auto_version="latest")

# Create a run. Note that keyword arguments are currently required; positional
# arguments (based on declaration order) may be accepted in the future, but we
# still recommend keyword arguments.
run = flyte.run(remote_task, numbers=[1.0, 2.0, 3.0])
print(f"Execution URL: {run.url}")

# You can view the phase (run.phase is available in flyte >= 2.0.0b39)
print(f"Current phase: {run.phase}")

# You can wait for the execution
run.wait()
print(f"Current phase: {run.phase}")

# Phases can be compared
if run.phase == flyte.models.ActionPhase.SUCCEEDED:
    print("Run completed!")

# You can then retrieve the outputs
print(f"Pipeline result: {run.outputs()}")
```
## Why use remote tasks?
Remote tasks solve common collaboration and dependency management challenges:
**Cross-team collaboration**: Team A has deployed a Spark task that analyzes large datasets. Team B needs this analysis for their ML pipeline but doesn't want to learn Spark internals, install Spark dependencies, or build Spark-enabled container images. With remote tasks, Team B simply references Team A's deployed task.
**Platform reusability**: Platform teams can create common, reusable tasks (data validation, feature engineering, model serving) that other teams can use without duplicating code or managing complex dependencies.
**Microservices for data workflows**: Remote tasks work like microservices for long-running tasks or agents, enabling secure sharing while maintaining isolation.
## When to use remote tasks
Use remote tasks when you need to:
- Use functionality from another team without their dependencies
- Share common tasks across your organization
- Build reusable platform components
- Avoid dependency conflicts between different parts of your workflow
- Create modular, maintainable data pipelines
## How remote tasks work
### Security model
Remote tasks run in the **caller's project and domain** using the caller's compute resources, but execute with the **callee's service accounts, IAM roles, and secrets**. This ensures:
- Tasks are secure from misuse
- Resource usage is properly attributed
- Authentication and authorization are maintained
- Collaboration remains safe and controlled
### Type system
Remote tasks use Flyte's default types as inputs and outputs. Flyte's type system seamlessly translates data between tasks without requiring the original dependencies:
| Remote Task Type | Flyte Type |
|-------------------|------------|
| DataFrames (`pandas`, `polars`, `spark`, etc.) | `flyte.io.DataFrame` |
| Object store files | `flyte.io.File` |
| Object store directories | `flyte.io.Dir` |
| Pydantic models | Dictionary (Flyte creates a representation) |
Any DataFrame type (pandas, polars, spark) automatically becomes `flyte.io.DataFrame`, allowing seamless data exchange between tasks using different DataFrame libraries. You can also write custom integrations or explore Flyte's plugin system for additional types.
For Pydantic models specifically, you don't need the exact model locally. Pass a dictionary as input, and Flyte will handle the translation.
## Versioning options
Reference tasks support flexible versioning:
**Specific version**:
```python
task = flyte.remote.Task.get(
    "team_a.process_data",
    version="v1.2.3"
)
```
**Latest version** (`auto_version="latest"`):
```python
# Always use the most recently deployed version
task = flyte.remote.Task.get(
    "team_a.process_data",
    auto_version="latest"
)
```
**Current version** (`auto_version="current"`):
```python
# Use the same version as the calling task's deployment
# Useful when all environments deploy with the same version
# Can only be used from within a task context
task = flyte.remote.Task.get(
    "team_a.process_data",
    auto_version="current"
)
```
## Customizing remote tasks
Remote tasks can be customized by overriding various properties without modifying the original deployed task. This allows you to adjust resource requirements, retry strategies, caching behavior, and more based on your specific use case.
### Available overrides
The `override()` method on remote tasks accepts the following parameters:
- **short_name** (`str`): A short name for the task instance
- **resources** (`flyte.Resources`): CPU, memory, GPU, and storage limits
- **retries** (`int | flyte.RetryStrategy`): Number of retries or retry strategy
- **timeout** (`flyte.TimeoutType`): Task execution timeout
- **env_vars** (`Dict[str, str]`): Environment variables to set
- **secrets** (`flyte.SecretRequest`): Secrets to inject
- **max_inline_io_bytes** (`int`): Maximum size for inline IO in bytes
- **cache** (`flyte.Cache`): Cache behavior and settings
- **queue** (`str`): Execution queue to use
### Override examples
**Increase resources for a specific use case**:
```python
import flyte
import flyte.remote

# Get the base task
data_processor = flyte.remote.Task.get(
    "data_team.spark_analyzer",
    auto_version="latest"
)

# Override with more resources for large dataset processing
large_data_processor = data_processor.override(
    resources=flyte.Resources(
        cpu="16",
        memory="64Gi",
        storage="200Gi"
    )
)

@env.task
async def process_large_dataset(data_path: str):
    # Use the customized version
    return await large_data_processor(input_path=data_path)
```
**Add retries and timeout**:
```python
# Override with retries and timeout for unreliable operations
reliable_processor = data_processor.override(
    retries=3,
    timeout="2h"
)

@env.task
async def robust_pipeline(data_path: str):
    return await reliable_processor(input_path=data_path)
```
**Configure caching**:
```python
# Override cache settings
cached_processor = data_processor.override(
    cache=flyte.Cache(
        behavior="override",
        version_override="v2",
        serialize=True
    )
)
```
**Set environment variables and secrets**:
```python
# Override with custom environment variables and secrets
custom_processor = data_processor.override(
    env_vars={
        "LOG_LEVEL": "DEBUG",
        "REGION": "us-west-2"
    },
    secrets=flyte.SecretRequest(
        secrets={"api_key": "my-secret-key"}
    )
)
```
**Multiple overrides**:
```python
# Combine multiple overrides
production_processor = data_processor.override(
    short_name="prod_spark_analyzer",
    resources=flyte.Resources(cpu="8", memory="32Gi"),
    retries=5,
    timeout="4h",
    env_vars={"ENV": "production"},
    queue="high-priority"
)

@env.task
async def production_pipeline(data_path: str):
    return await production_processor(input_path=data_path)
```
### Chain overrides
You can chain multiple `override()` calls to incrementally adjust settings:
```python
# Start with the base task
processor = flyte.remote.Task.get("data_team.analyzer", auto_version="latest")

# Add resources
processor = processor.override(resources=flyte.Resources(cpu="4", memory="16Gi"))

# Add retries for production
if is_production:  # e.g., a flag from your configuration
    processor = processor.override(retries=5, timeout="2h")

# Use the customized task
result = await processor(input_path="s3://data.parquet")
```
## Best practices
### 1. Use meaningful task names
Remote tasks are accessed by name, so use clear, descriptive naming:
```python
# Good
customer_segmentation = flyte.remote.Task.get("ml_platform.customer_segmentation")
# Avoid
task1 = flyte.remote.Task.get("team_a.task1")
```
### 2. Document task interfaces
Since remote tasks abstract away implementation details, clear documentation of inputs, outputs, and behavior is essential:
```python
@env.task
async def process_customer_data(
    customer_ids: list[str],
    date_range: tuple[str, str]
) -> flyte.io.DataFrame:
    """
    Process customer data for the specified date range.

    Args:
        customer_ids: List of customer IDs to process
        date_range: Tuple of (start_date, end_date) in YYYY-MM-DD format

    Returns:
        DataFrame with processed customer features
    """
    ...
```
### 3. Prefer module-level loading
Load remote tasks at the module level rather than inside functions for cleaner code:
```python
import flyte.remote

# Good - module level
data_processor = flyte.remote.Task.get("team.processor", auto_version="latest")

@env.task
async def my_task(data: str):
    return await data_processor(input=data)
```
This approach:
- Makes dependencies clear and discoverable
- Reduces code duplication
- Works well with lazy loading (no performance penalty)
Dynamic loading within tasks is also supported when you need runtime flexibility.
### 4. Handle versioning thoughtfully
- Use `auto_version="latest"` during development for rapid iteration
- Use specific versions in production for stability and reproducibility
- Use `auto_version="current"` when coordinating multi-environment deployments
### 5. Deploy remote tasks first
Always deploy remote tasks before using them. Tasks that reference them can then be run directly, without being deployed themselves:
Deploy the remote task environments first:
```bash
flyte deploy spark_env/
flyte deploy ml_env/
```
Then run the orchestration task directly (no deployment needed):
```bash
flyte run orchestration_env.py orchestrate_pipeline
```
If you want to deploy the orchestration task as well (for scheduled runs or to be referenced by other tasks), deploy it after its dependencies:
```bash
flyte deploy orchestration_env/
```
## Limitations
1. **Lazy error detection**: Because of lazy loading, errors about missing or invalid tasks surface only at invocation time, not when calling `get()`. You'll receive a `flyte.errors.RemoteTaskNotFoundError` if the task doesn't exist, and a `flyte.errors.RemoteTaskUsageError` if it can't be invoked with the arguments or overrides you pass.
2. **Type fidelity**: While Flyte translates types seamlessly, you work with Flyte's representation of Pydantic models, not the exact original types.
3. **Deployment order**: Referenced tasks must be deployed before tasks that reference them can be invoked.
4. **Context requirement**: Using `auto_version="current"` requires running within a task context.
5. **Dictionary inputs**: Pydantic models must be passed as dictionaries, which loses compile-time type checking.
6. **No positional arguments**: Remote tasks currently only support keyword arguments (this may change in future versions).
## Next steps
- Learn about [task deployment](../task-deployment/_index)
- Explore [task environments and configuration](../task-configuration/_index)
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/error-handling ===
# Error handling
One of the key features of Flyte 2 is the ability to recover from user-level errors in a workflow execution.
This includes out-of-memory errors and other exceptions.
In a distributed system with heterogeneous compute, certain types of errors are expected and even, in a sense, acceptable.
Flyte 2 recognizes this and allows you to handle them gracefully as part of your workflow logic.
This ability is a direct result of the fact that workflows are now written in regular Python,
giving you all the power and flexibility of Python error handling.
Let's look at an example:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# main = "main"
# params = ""
# ///
import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment(name="fail", resources=flyte.Resources(cpu=1, memory="250Mi"))
@env.task
async def oomer(x: int):
    large_list = [0] * 100000000
    print(len(large_list))

@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42

@env.task
async def main() -> int:
    try:
        await oomer(2)
    except flyte.errors.OOMError as e:
        print(f"Failed with oom trying with more resources: {e}, of type {type(e)}, {e.code}")
        try:
            await oomer.override(resources=flyte.Resources(cpu=1, memory="1Gi"))(5)
        except flyte.errors.OOMError as e:
            print(f"Failed with OOM Again giving up: {e}, of type {type(e)}, {e.code}")
            raise e
        finally:
            await always_succeeds()
    return await always_succeeds()

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/error-handling/error_handling.py*
In this code, we do the following:
* Import the necessary modules
* Set up the task environment. Note that we define our task environment with a resource allocation of 1 CPU and 250 MiB of memory.
* Define two tasks: one that will intentionally cause an out-of-memory (OOM) error, and another that will always succeed.
* Define the main task (the top-level workflow task) that will handle the failure recovery logic.
The outer `try...except` block attempts to run the `oomer` task with a parameter that is likely to cause an OOM error.
If the error occurs, it catches the [`flyte.errors.OOMError`](../../api-reference/flyte-sdk/packages/flyte.errors/oomerror) and attempts to run the `oomer` task again with increased resources.
This type of dynamic error handling allows you to gracefully recover from user-level errors in your workflows.
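The escalate-on-failure shape used in `main` above is ordinary Python and can be exercised without Flyte at all. In this illustrative sketch, `allocate` is a hypothetical worker that raises `MemoryError` when its memory budget is too small, standing in for an OOM-killed task:

```python
def allocate(n_items: int, budget_kb: int) -> list[int]:
    """Hypothetical worker: fails if the memory budget is too small."""
    if n_items * 8 / 1000 > budget_kb:  # rough bytes-to-KB estimate
        raise MemoryError(f"{budget_kb} KB is not enough for {n_items} items")
    return [0] * n_items

def run_with_escalation(n_items: int) -> list[int]:
    """Try with a small budget first, then retry once with more memory."""
    try:
        return allocate(n_items, budget_kb=250)
    except MemoryError:
        # Same shape as oomer.override(resources=...) above:
        # retry the same work with a larger resource allocation.
        return allocate(n_items, budget_kb=1000)
```

The difference in Flyte is that the retry runs in a fresh container with genuinely larger resources, rather than in the same process.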
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/traces ===
# Traces
The `@flyte.trace` decorator provides fine-grained observability and resumption capabilities for functions called within your Flyte workflows.
Traces are used on **helper functions** that tasks call to perform specific operations like API calls, data processing, or computations.
Traces are particularly useful for [managing the challenges of non-deterministic behavior in workflows](../flyte-2/considerations#non-deterministic-behavior), allowing you to track execution details and resume from failures.
## What are traced functions for?
At the top level, Flyte workflows are composed of **tasks**. But it is also common practice to break complex task logic into smaller, reusable helper functions that tasks call to perform specific operations.
Any helper functions defined or imported into the same file as a task definition are automatically uploaded to the Flyte environment alongside the task when it is deployed.
At the task level, observability and resumption of failed executions are provided by caching. But what if you want these capabilities at a more granular level, for the individual operations that tasks perform?
This is where **traced functions** come in. By decorating helper functions with `@flyte.trace`, you enable:
- **Detailed observability**: Track execution time, inputs/outputs, and errors for each function call.
- **Fine-grained resumption**: If a workflow fails, resume from the last successful traced function instead of re-running the entire task.
Each traced function is effectively a checkpoint within its task.
Here is an example:
```python
import asyncio

import flyte

env = flyte.TaskEnvironment("env")

@flyte.trace
async def call_llm(prompt: str) -> str:
    await asyncio.sleep(0.1)
    return f"LLM response for: {prompt}"

@flyte.trace
async def process_data(data: str) -> dict:
    await asyncio.sleep(0.2)
    return {"processed": data, "status": "completed"}

@env.task
async def research_workflow(topic: str) -> dict:
    llm_result = await call_llm(f"Generate research plan for: {topic}")
    processed_data = await process_data(llm_result)
    return {"topic": topic, "result": processed_data}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/task_vs_trace.py*
## What gets traced
Traces capture detailed execution information:
- **Execution time**: How long each function call takes.
- **Inputs and outputs**: Function parameters and return values.
- **Checkpoints**: State that enables workflow resumption.
### Errors are not recorded
Only successful trace executions are recorded in the checkpoint system. When a traced function fails, the exception propagates up to your task code where you can handle it with standard error handling patterns.
### Supported function types
The trace decorator works with:
- **Asynchronous functions**: Functions defined with `async def`.
- **Generator functions**: Functions that `yield` values.
- **Async generators**: `async def` functions that `yield` values.
> [!NOTE]
> Currently tracing only works for asynchronous functions. Tracing of synchronous functions is coming soon.
```python
@flyte.trace
async def async_api_call(topic: str) -> dict:
    # Asynchronous API call
    await asyncio.sleep(0.1)
    return {"data": ["item1", "item2", "item3"], "status": "success"}

@flyte.trace
async def stream_data(items: list[str]):
    # Async generator function for streaming
    for item in items:
        await asyncio.sleep(0.02)
        yield f"Processing: {item}"

@flyte.trace
async def async_stream_llm(prompt: str):
    # Async generator for streaming LLM responses
    chunks = ["Research shows", " that machine learning", " continues to evolve."]
    for chunk in chunks:
        await asyncio.sleep(0.05)
        yield chunk

@env.task
async def research_workflow(topic: str) -> dict:
    llm_result = await async_api_call(topic)

    # Collect async generator results
    processed_data = []
    async for item in stream_data(llm_result["data"]):
        processed_data.append(item)

    llm_stream = []
    async for chunk in async_stream_llm(f"Summarize research on {topic}"):
        llm_stream.append(chunk)

    return {
        "topic": topic,
        "processed_data": processed_data,
        "llm_summary": "".join(llm_stream)
    }
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/function_types.py*
## Task orchestration pattern
The typical Flyte workflow follows this pattern:
```python
@flyte.trace
async def search_web(query: str) -> list[dict]:
    # Search the web and return results
    await asyncio.sleep(0.1)
    return [{"title": f"Article about {query}", "content": f"Content on {query}"}]

@flyte.trace
async def summarize_content(content: str) -> str:
    # Summarize content using LLM
    await asyncio.sleep(0.1)
    return f"Summary of {len(content.split())} words"

@flyte.trace
async def extract_insights(summaries: list[str]) -> dict:
    # Extract insights from summaries
    await asyncio.sleep(0.1)
    return {"insights": ["key theme 1", "key theme 2"], "count": len(summaries)}

@env.task
async def research_pipeline(topic: str) -> dict:
    # Each helper function creates a checkpoint
    search_results = await search_web(f"research on {topic}")

    summaries = []
    for result in search_results:
        summary = await summarize_content(result["content"])
        summaries.append(summary)

    final_insights = await extract_insights(summaries)

    return {
        "topic": topic,
        "insights": final_insights,
        "sources_count": len(search_results)
    }
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/pattern.py*
**Benefits of this pattern:**
- If `search_web` succeeds but `summarize_content` fails, resumption skips the search step
- Each operation is independently observable and debuggable
- Clear separation between workflow coordination (task) and execution (traced functions)
## Relationship to caching and checkpointing
Understanding how traces work with Flyte's other execution features:
| Feature | Scope | Purpose | Default Behavior |
|---------|-------|---------|------------------|
| **Task Caching** | Entire task execution (`@env.task`) | Skip re-running tasks with same inputs | Enabled (`cache="auto"`) |
| **Traces** | Individual helper functions | Observability and fine-grained resumption | Manual (requires `@flyte.trace`) |
| **Checkpointing** | Workflow state | Resume workflows from failure points | Automatic when traces are used |
### How they work together
```python
from typing import List

@flyte.trace
async def traced_data_cleaning(dataset_id: str) -> List[str]:
    # Creates checkpoint after successful execution.
    await asyncio.sleep(0.2)
    return [f"cleaned_record_{i}_{dataset_id}" for i in range(100)]

@flyte.trace
async def traced_feature_extraction(data: List[str]) -> dict:
    # Creates checkpoint after successful execution.
    await asyncio.sleep(0.3)
    return {
        "features": [f"feature_{i}" for i in range(10)],
        "feature_count": len(data),
        "processed_samples": len(data)
    }

@flyte.trace
async def traced_model_training(features: dict) -> dict:
    # Creates checkpoint after successful execution.
    await asyncio.sleep(0.4)
    sample_count = features["processed_samples"]
    # Mock accuracy based on sample count
    accuracy = min(0.95, 0.7 + (sample_count / 1000))
    return {
        "accuracy": accuracy,
        "epochs": 50,
        "model_size": "125MB"
    }

@env.task(cache="auto")  # Task-level caching enabled
async def data_pipeline(dataset_id: str) -> dict:
    # 1. If this exact task with these inputs ran before,
    #    the entire task result is returned from cache
    # 2. If not cached, execution begins and each traced function
    #    creates checkpoints for resumption
    cleaned_data = await traced_data_cleaning(dataset_id)     # Checkpoint 1
    features = await traced_feature_extraction(cleaned_data)  # Checkpoint 2
    model_results = await traced_model_training(features)     # Checkpoint 3

    # 3. If the workflow fails at step 3, resumption will:
    #    - Skip traced_data_cleaning (checkpointed)
    #    - Skip traced_feature_extraction (checkpointed)
    #    - Re-run only traced_model_training
    return {"dataset_id": dataset_id, "accuracy": model_results["accuracy"]}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/caching_vs_checkpointing.py*
### Execution flow
1. **Task Submission**: Task is submitted with input parameters
2. **Cache Check**: Flyte checks if identical task execution exists in cache
3. **Cache Hit**: If cached, return cached result immediately (no traces needed)
4. **Cache Miss**: Begin fresh execution
5. **Trace Checkpoints**: Each `@flyte.trace` function creates resumption points
6. **Failure Recovery**: If workflow fails, resume from last successful checkpoint
7. **Task Completion**: Final result is cached for future identical inputs
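The checkpoint-and-resume behavior in steps 5 and 6 can be mimicked in plain Python, with a dictionary standing in for the checkpoint store. This is an illustrative sketch of the idea only, not Flyte's actual implementation:

```python
import asyncio

checkpoints: dict[str, object] = {}  # stand-in for Flyte's checkpoint store
calls: list[str] = []                # records which steps actually executed

def trace(fn):
    """Record each successful call; on re-run, skip already-completed steps."""
    async def wrapper(*args):
        key = f"{fn.__name__}:{args!r}"
        if key in checkpoints:         # resumption: step already done
            return checkpoints[key]
        result = await fn(*args)
        checkpoints[key] = result      # checkpoint only on success
        return result
    return wrapper

@trace
async def clean(dataset: str) -> str:
    calls.append("clean")
    return f"cleaned-{dataset}"

@trace
async def train(data: str) -> str:
    calls.append("train")
    if len(calls) < 3:                 # fail on the first attempt only
        raise RuntimeError("transient failure")
    return f"model-for-{data}"

async def pipeline(dataset: str) -> str:
    data = await clean(dataset)
    return await train(data)

async def run() -> str:
    try:
        return await pipeline("d1")    # first run fails inside train()
    except RuntimeError:
        return await pipeline("d1")    # resume: clean() is skipped
```

On the second run, `clean` returns its checkpointed result without executing again, and only `train` is re-run.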
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/grouping-actions ===
# Grouping actions
Groups are an organizational feature in Flyte that allow you to logically cluster related task invocations (called "actions") for better visualization and management in the UI.
Groups help you organize task executions into manageable, hierarchical structures regardless of whether you're working with large fanouts or smaller, logically-related sets of operations.
## What are groups?
Groups provide a way to organize task invocations into logical units in the Flyte UI.
When you have multiple task executions, whether from large [fanouts](./fanout), sequential operations, or any combination of tasks, groups help organize them into manageable units.
### The problem groups solve
Without groups, complex workflows can become visually overwhelming in the Flyte UI:
- Multiple task executions appear as separate nodes, making it hard to see the high-level structure
- Related operations are scattered throughout the workflow graph
- Debugging and monitoring becomes difficult when dealing with many individual task executions
Groups solve this by:
- **Organizing actions**: Multiple task executions within a group are presented as a hierarchical "folder" structure
- **Improving UI visualization**: Instead of many individual nodes cluttering the view, you see logical groups that can be collapsed or expanded
- **Aggregating status information**: Groups show aggregated run status (success/failure) of their contained actions when you hover over them in the UI
- **Maintaining execution parallelism**: Tasks still run concurrently as normal, but are organized for display
### How groups work
Groups are declared using the [`flyte.group`](../../api-reference/flyte-sdk/packages/flyte/_index#group) context manager.
Any task invocations that occur within the `with flyte.group()` block are automatically associated with that group:
```python
with flyte.group("my-group-name"):
    # All task invocations here belong to "my-group-name"
    result1 = await task_a(data)
    result2 = await task_b(data)
    result3 = await task_c(data)
```
The key points about groups:
1. **Context-based**: Use the `with flyte.group("name"):` context manager.
2. **Organizational tool**: Task invocations within the context are grouped together in the UI.
3. **UI folders**: Groups appear as collapsible/expandable folders in the Flyte UI run tree.
4. **Status aggregation**: Hover over a group in the UI to see aggregated success/failure information.
5. **Execution unchanged**: Tasks still execute in parallel as normal; groups only affect organization and visualization.
**Important**: Groups do not aggregate outputs. Each task execution still produces its own individual outputs. Groups are purely for organization and UI presentation.
## Common grouping patterns
### Sequential operations
Group related sequential operations that logically belong together:
```python
@env.task
async def data_pipeline(raw_data: str) -> str:
    with flyte.group("data-validation"):
        validated_data = await process_data(raw_data, "validate_schema")
        validated_data = await process_data(validated_data, "check_quality")
        validated_data = await process_data(validated_data, "remove_duplicates")

    with flyte.group("feature-engineering"):
        features = await process_data(validated_data, "extract_features")
        features = await process_data(features, "normalize_features")
        features = await process_data(features, "select_features")

    with flyte.group("model-training"):
        model = await process_data(features, "train_model")
        model = await process_data(model, "validate_model")
        final_model = await process_data(model, "save_model")

    return final_model
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py*
### Parallel processing with groups
Groups work well with parallel execution patterns:
```python
@env.task
async def parallel_processing_example(n: int) -> str:
    tasks = []
    with flyte.group("parallel-processing"):
        # Collect all task invocations first
        for i in range(n):
            tasks.append(process_item(i, "transform"))
        # Execute all tasks in parallel
        results = await asyncio.gather(*tasks)
    # Convert to string for consistent return type
    return f"parallel_results: {results}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py*
### Multi-phase workflows
Use groups to organize different phases of complex workflows:
```python
@env.task
async def multi_phase_workflow(data_size: int) -> str:
    # First phase: data preprocessing
    preprocessed = []
    with flyte.group("preprocessing"):
        for i in range(data_size):
            preprocessed.append(process_item(i, "preprocess"))
        phase1_results = await asyncio.gather(*preprocessed)

    # Second phase: main processing
    processed = []
    with flyte.group("main-processing"):
        for result in phase1_results:
            processed.append(process_item(result, "transform"))
        phase2_results = await asyncio.gather(*processed)

    # Third phase: postprocessing
    postprocessed = []
    with flyte.group("postprocessing"):
        for result in phase2_results:
            postprocessed.append(process_item(result, "postprocess"))
        final_results = await asyncio.gather(*postprocessed)

    # Convert to string for consistent return type
    return f"multi_phase_results: {final_results}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py*
### Nested groups
Groups can be nested to create hierarchical organization:
```python
@env.task
async def hierarchical_example(raw_data: str) -> str:
    with flyte.group("data-preparation"):
        cleaned_data = await process_data(raw_data, "clean_data")
        split_data = await process_data(cleaned_data, "split_dataset")

        with flyte.group("hyperparameter-tuning"):
            best_params = await process_data(split_data, "tune_hyperparameters")

        with flyte.group("model-training"):
            model = await process_data(best_params, "train_final_model")

    return model
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py*
### Conditional grouping
Groups can be used with conditional logic:
```python
@env.task
async def conditional_processing(use_advanced_features: bool, input_data: str) -> str:
    base_result = await process_data(input_data, "basic_processing")

    if use_advanced_features:
        with flyte.group("advanced-features"):
            enhanced_result = await process_data(base_result, "advanced_processing")
            optimized_result = await process_data(enhanced_result, "optimize_result")
            return optimized_result
    else:
        with flyte.group("basic-features"):
            simple_result = await process_data(base_result, "simple_processing")
            return simple_result
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py*
## Key insights
Groups are primarily an organizational and UI visualization tool: they don't change how your tasks execute or aggregate their outputs. They simply collect related task invocations (actions) into collapsible, folder-like structures, with aggregated status information (success/failure rates) visible when you hover over a group folder in the UI.
Groups make your Flyte workflows more maintainable and easier to understand, especially when a workflow involves multiple logical phases or large numbers of task executions. They act as organizational "folders" in the UI's call stack tree, letting you collapse sections to reduce visual clutter while still seeing aggregated status on hover.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/fanout ===
# Fanout
Flyte is designed to scale effortlessly, allowing you to run workflows with large fanouts.
When you need to execute many tasks in parallel, such as processing a large dataset or running hyperparameter sweeps, Flyte provides powerful patterns to implement these operations efficiently.
## Understanding fanout
A "fanout" pattern occurs when you spawn multiple tasks concurrently.
Each task runs in its own container and contributes an output that you later collect.
The most common way to implement this is using the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.
In Flyte terminology, each individual task execution is called an "action": a specific invocation of a task with particular inputs. When you call a task multiple times in a loop, you create multiple actions.
## Example
We start by importing our required packages, defining our Flyte environment, and creating a simple task that fetches user data from a mock API.
```python
import asyncio
from typing import List, Tuple

import flyte

env = flyte.TaskEnvironment("fanout_env")

@env.task
async def fetch_data(user_id: int) -> dict:
    """Simulate fetching user data from an API - good for parallel execution."""
    # Simulate network I/O delay
    await asyncio.sleep(0.1)
    return {
        "user_id": user_id,
        "name": f"User_{user_id}",
        "score": user_id * 10,
        "data": f"fetched_data_{user_id}"
    }
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py*
### Parallel execution
Next we implement the most common fanout pattern, which is to collect task invocations and execute them in parallel using `asyncio.gather()`:
```python
@env.task
async def parallel_data_fetching(user_ids: List[int]) -> List[dict]:
    """Fetch data for multiple users in parallel - ideal for I/O bound operations."""
    tasks = []
    # Collect all fetch tasks - these can run in parallel since they're independent
    for user_id in user_ids:
        tasks.append(fetch_data(user_id))
    # Execute all fetch operations in parallel
    results = await asyncio.gather(*tasks)
    return results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py*
### Running the example
To actually run our example, we create a main guard that initializes Flyte and runs our main driver task:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    user_ids = [1, 2, 3, 4, 5]
    r = flyte.run(parallel_data_fetching, user_ids)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py*
## How Flyte handles concurrency and parallelism
In the example we use a standard `asyncio.gather()` pattern.
When this pattern is used in a normal Python environment, the tasks would execute **concurrently** (cooperatively sharing a single thread through the event loop), but not in true **parallel** (multiple CPU cores simultaneously).
However, **Flyte transforms this concurrency model into true parallelism**. When you use `asyncio.gather()` in a Flyte task:
1. **Flyte acts as a distributed event loop**: Instead of scheduling coroutines on a single machine, Flyte schedules each task action to run in its own container across the cluster
2. **Concurrent becomes parallel**: What would be cooperative multitasking in regular Python becomes true parallel execution across multiple machines
3. **Native Python patterns**: You use familiar `asyncio` patterns, but Flyte automatically distributes the work
This means that when you write:
```python
results = await asyncio.gather(fetch_data(1), fetch_data(2), fetch_data(3))
```
Instead of three coroutines sharing one CPU, you get three separate containers running simultaneously, each with its own CPU, memory, and resources. Flyte seamlessly bridges the gap between Python's concurrency model and distributed parallel computing, allowing massive scalability while maintaining the familiar async/await programming model.
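The single-machine baseline that Flyte upgrades can be seen directly with plain `asyncio`: concurrent sleeps overlap in time on one event loop, whereas under Flyte each call would run in its own container. This standalone sketch just demonstrates the overlap:

```python
import asyncio
import time

async def fake_fetch(i: int) -> int:
    # An I/O-bound wait; the event loop interleaves all of these.
    await asyncio.sleep(0.05)
    return i

async def main() -> tuple[list[int], float]:
    start = time.perf_counter()
    # Eight 0.05 s waits overlap, so total time is ~0.05 s, not 0.4 s.
    results = await asyncio.gather(*[fake_fetch(i) for i in range(8)])
    elapsed = time.perf_counter() - start
    return results, elapsed
```

On one machine this is cooperative concurrency; the same `gather` over Flyte tasks becomes parallel containers.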
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/controlling-parallelism ===
# Controlling parallel execution
When you [fan out](./fanout) to many tasks, you often need to limit how many run at the same time.
Common reasons include rate-limited APIs, GPU quotas, database connection limits, or simply avoiding overwhelming a downstream service.
Flyte 2 provides two ways to control concurrency:
[`asyncio.Semaphore`](https://docs.python.org/3/library/asyncio-sync.html#asyncio.Semaphore) for fine-grained control,
and `flyte.map` with a built-in `concurrency` parameter for simpler cases.
## The problem: unbounded parallelism
A straightforward `asyncio.gather` launches every task at once.
If you are calling an external API that allows only a few concurrent requests, this can cause throttling or errors:
```python
import asyncio

import flyte

env = flyte.TaskEnvironment("controlling_parallelism")

@env.task
async def call_llm_api(prompt: str) -> str:
    """Simulate calling a rate-limited LLM API."""
    # In a real workflow, this would call an external API.
    # The API might allow only a few concurrent requests.
    await asyncio.sleep(0.5)
    return f"Response to: {prompt}"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/controlling-parallelism/controlling_parallelism.py*
```python
@env.task
async def process_all_at_once(prompts: list[str]) -> list[str]:
    """Send all requests in parallel with no concurrency limit.

    This can overwhelm a rate-limited API, causing errors or throttling.
    """
    results = await asyncio.gather(*[call_llm_api(p) for p in prompts])
    return list(results)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/controlling-parallelism/controlling_parallelism.py*
With eight prompts, this fires eight concurrent API calls.
That works fine when there are no limits, but will fail when the API enforces a concurrency cap.
## Using asyncio.Semaphore
An `asyncio.Semaphore` acts as a gate: only a fixed number of tasks can pass through at a time.
The rest wait until a slot opens up.
```python
@env.task
async def process_batch_with_semaphore(
    prompts: list[str],
    max_concurrent: int = 3,
) -> list[str]:
    """Process prompts in parallel, limiting concurrency with a semaphore.

    At most `max_concurrent` calls to the API run at any given time.
    The remaining tasks wait until a slot is available.
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def limited_call(prompt: str) -> str:
        async with semaphore:
            return await call_llm_api(prompt)

    results = await asyncio.gather(*[limited_call(p) for p in prompts])
    return list(results)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/controlling-parallelism/controlling_parallelism.py*
The pattern is:
1. Create a semaphore with the desired limit.
2. Wrap each task call in an inner async function that acquires the semaphore before calling and releases it after.
3. Pass all wrapped calls to `asyncio.gather`.
All eight tasks are submitted immediately, but the semaphore allows only three to run in parallel.
As each one completes, the next waiting task starts.
> [!NOTE]
> The semaphore controls how many tasks execute concurrently on the Flyte cluster.
> Each task still runs in its own container with its own resources; the semaphore simply limits how many containers are active at a time.
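The semaphore's effect can also be verified with plain `asyncio`, outside any cluster. This standalone sketch counts how many coroutines are inside the gate at once:

```python
import asyncio

async def main(n_tasks: int, max_concurrent: int) -> int:
    semaphore = asyncio.Semaphore(max_concurrent)
    active = 0
    peak = 0

    async def worker(i: int) -> None:
        nonlocal active, peak
        async with semaphore:
            active += 1
            peak = max(peak, active)   # record peak concurrency
            await asyncio.sleep(0.01)  # simulate an API call
            active -= 1

    await asyncio.gather(*[worker(i) for i in range(n_tasks)])
    return peak

# With 8 tasks and a limit of 3, the recorded peak never exceeds 3.
```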
## Using flyte.map with concurrency
For uniform work (applying the same task to a list of inputs), `flyte.map` with the `concurrency` parameter is simpler:
CODE2
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/controlling-parallelism/controlling_parallelism.py*
This achieves the same concurrency limit with less boilerplate.
## Running the example
CODE3
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/controlling-parallelism/controlling_parallelism.py*
## When to use each approach
Use **`flyte.map(concurrency=N)`** when:
- Every item goes through the same task.
- You want the simplest possible code.
Use **`asyncio.Semaphore`** when:
- You need different concurrency limits for different task types within the same workflow.
- You want to combine concurrency control with error handling (e.g., `asyncio.gather(*tasks, return_exceptions=True)`).
- You are calling multiple different tasks in one parallel batch.
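As a plain-`asyncio` illustration of the error-handling point above, a semaphore can be combined with `return_exceptions=True` so that one failing call doesn't abort the rest. The failing `flaky_call` here is a hypothetical stand-in for a task:

```python
import asyncio

async def flaky_call(i: int) -> str:
    """Hypothetical worker that fails on one specific input."""
    await asyncio.sleep(0.01)
    if i == 2:
        raise ValueError(f"bad input: {i}")
    return f"ok-{i}"

async def main() -> list[object]:
    semaphore = asyncio.Semaphore(3)

    async def limited(i: int) -> str:
        async with semaphore:
            return await flaky_call(i)

    # Exceptions are returned in place of results instead of propagating,
    # so the successful calls still complete.
    return await asyncio.gather(
        *[limited(i) for i in range(5)], return_exceptions=True
    )
```

The returned list mixes results and exception objects, which you can then inspect per item.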
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/human-in-the-loop ===
# Human-in-the-loop
Human-in-the-loop (HITL) workflows pause execution at a defined point, wait for a human to provide input or approval, and then continue based on that response. Common use cases include content moderation gates, model output review, anomaly confirmation, and manual approval steps before costly or irreversible operations.
The `flyteplugins-hitl` package provides an event-based API for this pattern. When an event is created, Flyte automatically serves a small FastAPI web app with a form where a human can submit input. The workflow then resumes with the submitted value.
```bash
pip install flyteplugins-hitl
```
Key characteristics:
- Supports `int`, `float`, `str`, and `bool` input types
- Crash-resilient: uses durable sleep so polling survives task restarts
- Configurable timeout and poll interval
- The web form is accessible from the task's report in the Flyte UI
## Setup
The task environment must declare `hitl.env` as a dependency. This makes the HITL web app available during task execution:
```python
import flyte
import flyteplugins.hitl as hitl

# The task environment must declare hitl.env as a dependency.
# This makes the HITL web app available during task execution.
env = flyte.TaskEnvironment(
    name="hitl-workflow",
    image=flyte.Image.from_debian_base(name="hitl").with_pip_packages(
        "flyteplugins-hitl>=2.0.0",
        "fastapi",
        "uvicorn",
        "python-multipart",
    ),
    resources=flyte.Resources(cpu="1", memory="512Mi"),
    depends_on=[hitl.env],
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/human-in-the-loop/hitl.py*
## Automated task
An automated task runs first and produces a result that requires human review:
```python
@env.task(report=True)
async def analyze_data(dataset: str) -> dict:
    """Automated task that produces a result requiring human review."""
    # Simulate analysis
    result = {
        "dataset": dataset,
        "row_count": 142857,
        "anomalies_detected": 3,
        "confidence": 0.87,
    }
    await flyte.report.replace.aio(
        f"Analysis complete: {result['anomalies_detected']} anomalies detected "
        f"(confidence: {result['confidence']:.0%})"
    )
    await flyte.report.flush.aio()
    return result
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/human-in-the-loop/hitl.py*
## Requesting human input
Use `hitl.new_event()` to pause and wait for a human response. The `prompt` is shown on the web form. The `data_type` controls what type the submitted value is converted to before being returned:
```python
@env.task(report=True)
async def request_human_review(analysis: dict) -> bool:
    """Pause and ask a human whether to proceed with the flagged records."""
    event = await hitl.new_event.aio(
        "review_decision",
        data_type=bool,
        scope="run",
        prompt=(
            f"Analysis found {analysis['anomalies_detected']} anomalies "
            f"with {analysis['confidence']:.0%} confidence. "
            "Approve for downstream processing? (true/false)"
        ),
    )
    approved: bool = await event.wait.aio()
    return approved
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/human-in-the-loop/hitl.py*
When this task runs, Flyte:
1. Serves the HITL web app (if not already running)
2. Creates an event and writes a pending request to object storage
3. Displays a link to the web form in the task report
4. Polls for a response using durable sleep
5. Returns the submitted value once input is received
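Steps 4 and 5 above amount to a poll loop with a deadline. The following is a stdlib-only conceptual sketch of that idea, not Flyte's actual implementation (which uses durable sleep so the loop survives task restarts); `wait_for_response` and its `check` callback are hypothetical names:

```python
import asyncio
import time

async def wait_for_response(check, timeout_seconds=3600, poll_interval_seconds=5):
    """Poll `check` until it returns a value, or raise after the timeout."""
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        value = check()  # e.g. look for a response object in storage
        if value is not None:
            return value
        await asyncio.sleep(poll_interval_seconds)
    raise TimeoutError("no human input received before the timeout")

# Simulate a reviewer responding on the third poll
responses = iter([None, None, "true"])
result = asyncio.run(
    wait_for_response(lambda: next(responses), timeout_seconds=5, poll_interval_seconds=0.01)
)
print(result)  # true
```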
## Wiring it together
The main task orchestrates the automated step and the HITL gate:
```
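```python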
@env.task(report=True)
async def main(dataset: str = "s3://my-bucket/data.parquet") -> str:
analysis = await analyze_data(dataset=dataset)
approved = await request_human_review(analysis=analysis)
    if approved:
        return "Processing approved: continuing pipeline."
    else:
        return "Processing rejected by reviewer: pipeline halted."
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
print(r.outputs())
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/human-in-the-loop/hitl.py*
## Event options
`hitl.new_event()` accepts the following parameters:
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `name` | `str` | (required) | Descriptive name shown in logs and the UI |
| `data_type` | `type` | (required) | Expected input type: `int`, `float`, `str`, or `bool` |
| `scope` | `str` | `"run"` | Scope of the event. Currently only `"run"` is supported |
| `prompt` | `str` | `"Please provide a value"` | Message shown on the web form |
| `timeout_seconds` | `int` | `3600` | Maximum time to wait before raising `TimeoutError` |
| `poll_interval_seconds` | `int` | `5` | How often to check for a response |
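Form submissions arrive as text, so `data_type` determines the conversion applied before the value reaches your task. A conceptual sketch of such a conversion (`convert_submission` is a hypothetical helper, not the plugin's actual code; note that a naive `bool(raw)` would treat `"false"` as truthy):

```python
def convert_submission(raw: str, data_type: type):
    """Convert a raw form string into the declared data_type."""
    if data_type is bool:
        # bool("false") is True, so boolean text must be parsed explicitly
        return raw.strip().lower() in ("true", "1", "yes")
    return data_type(raw)

print(convert_submission("true", bool))   # True
print(convert_submission("0.87", float))  # 0.87
print(convert_submission("3", int))       # 3
```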
## Submitting input programmatically
In addition to the web form, input can be submitted via the event's JSON API endpoint. This is useful for automated testing or integration with external approval systems:
```bash
curl -X POST https://<form-url>/submit/json \
  -H "Content-Type: application/json" \
  -d '{
    "request_id": "<request-id>",
    "response_path": "<response-path>",
    "value": "true",
    "data_type": "bool"
  }'
```
The `request_id` and `response_path` are shown in the task report alongside the form URL.
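The same submission can be made from Python with only the standard library. A sketch that builds the request without sending it; the URL, `request_id`, and `response_path` are placeholders to be copied from the task report:

```python
import json
import urllib.request

# Placeholder values; the real form URL, request_id, and response_path
# appear in the task report.
payload = {
    "request_id": "<request-id>",
    "response_path": "<response-path>",
    "value": "true",
    "data_type": "bool",
}
req = urllib.request.Request(
    "https://<form-url>/submit/json",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
# With real values filled in, urllib.request.urlopen(req) sends the submission.
```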
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/other-features ===
This section covers advanced programming patterns and techniques for working with Flyte tasks.
## Task Forwarding
When one task calls another task using the normal invocation syntax (e.g., `await inner_task(x)`), Flyte creates a durable action that's recorded in the UI with data passed through the metadata store. However, if you want to execute a task in the same Python VM without this overhead, use the `.forward()` method.
**When to use**: You want to avoid durability overhead and execute task logic directly in the current VM.
```python
import flyte
env = flyte.TaskEnvironment("my-env")
@env.task
async def inner_task(x: int) -> int:
return x + 1
@env.task
async def outer_task(x: int) -> int:
# Executes in same VM, no durable action created
v = await inner_task.forward(x=10)
# Creates a durable action, recorded in UI
return await inner_task(v)
```
The `.forward()` method works with both sync and async tasks:
```python
@env.task
def sync_inner_task(x: int) -> int:
return x + 1
@env.task
def sync_outer_task(x: int) -> int:
# Direct execution, no remote call
v = sync_inner_task.forward(x=10)
return sync_inner_task(v)
```
## Passing Tasks and Functions as Arguments
You can pass both Flyte tasks and regular Python functions as arguments to other tasks. Flyte handles this through pickling, so the code appears as pickled data in the UI.
```python
import typing
import flyte
env = flyte.TaskEnvironment("udfs")
@env.task
async def add_one_udf(x: int) -> int:
return x + 1
# Regular async function (not a task)
async def fn_add_two_udf(x: int) -> int:
return x + 2
@env.task
async def run_udf(x: int, udf: typing.Callable[[int], typing.Awaitable[int]]) -> int:
return await udf(x)
@env.task
async def main() -> list[int]:
# Pass a Flyte task as an argument
result_one = await run_udf(5, add_one_udf)
# Pass a regular function as an argument
result_two = await run_udf(5, fn_add_two_udf)
return [result_one, result_two]
```
**Note**: Both tasks and regular functions are serialized via pickling when passed as arguments.
## Custom Action Names
By default, actions in the UI use the task's function name. You can provide custom, user-friendly names using the `short_name` parameter.
### Set at Task Definition
```python
import flyte
env = flyte.TaskEnvironment("friendly_names")
@env.task(short_name="my_task")
async def some_task() -> str:
return "Hello, Flyte!"
```
### Override at Call Time
```python
@env.task(short_name="entrypoint")
async def main() -> str:
# Uses the default short_name "my_task"
s = await some_task()
# Overrides to use "my_name" for this specific action
return s + await some_task.override(short_name="my_name")()
```
This is useful when the same task is called multiple times with different contexts, making the UI more readable.
## Invoking Async Functions from Sync Tasks
When migrating from Flyte 1.x to 2.0, you may have legacy sync code that needs to call async functions. Use `nest_asyncio.apply()` to enable `asyncio.run()` within sync tasks.
```python
import asyncio
import nest_asyncio
import flyte
env = flyte.TaskEnvironment(
"async_in_sync",
image=flyte.Image.from_debian_base().with_pip_packages("nest_asyncio"),
)
# Apply at module level
nest_asyncio.apply()
async def async_function() -> str:
await asyncio.sleep(1)
return "done"
@env.task
def sync_task() -> str:
# Now you can use asyncio.run() in a sync task
return asyncio.run(async_function())
```
**Important**:
- Call `nest_asyncio.apply()` at the module level before defining tasks
- Add `nest_asyncio` to your image dependencies
- This is particularly useful during migration when you have mixed sync/async code
## Async and Sync Task Interoperability
When migrating from older sync-based code to async tasks, or when working with mixed codebases, you need to call sync tasks from async parent tasks. Flyte provides the `.aio` method on every task (even sync ones) to enable this.
### Calling Sync Tasks from Async Tasks
Every sync task automatically has an `.aio` property that returns an async-compatible version:
```python
import flyte
env = flyte.TaskEnvironment("mixed-tasks")
@env.task
def sync_task(x: int) -> str:
"""Legacy sync task"""
return f"Processed {x}"
@env.task
async def async_task(x: int) -> str:
"""New async task that calls legacy sync task"""
# Use .aio to call sync task from async context
result = await sync_task.aio(x)
return result
```
### Using with `flyte.map.aio()`
When you need to call sync tasks in parallel from an async context, use `flyte.map.aio()`:
```python
from typing import List
import flyte
env = flyte.TaskEnvironment("map-example")
@env.task
def sync_process(x: int) -> str:
"""Synchronous processing task"""
return f"Task {x}"
@env.task
async def async_main(n: int) -> List[str]:
"""Async task that maps over sync task"""
results = []
# Map over sync task from async context
async for result in flyte.map.aio(sync_process, range(n)):
if isinstance(result, Exception):
raise result
results.append(result)
return results
```
**Why this matters**: This pattern is powerful when migrating from Flyte 1.x or integrating legacy sync tasks with new async code. You don't need to rewrite all sync tasks at onceβthey can be called seamlessly from async contexts.
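Conceptually, `.aio` gives a sync function an awaitable calling convention so it composes with async code. A stdlib-only sketch of the idea using a thread offload (`make_aio` is a hypothetical illustration, not how Flyte implements `.aio`):

```python
import asyncio
import functools

def make_aio(fn):
    """Return an async wrapper that runs the sync function in a worker thread."""
    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        return await asyncio.to_thread(fn, *args, **kwargs)
    return wrapper

def sync_process(x: int) -> str:
    return f"Task {x}"

async def main() -> list[str]:
    aio_process = make_aio(sync_process)
    # Call the sync function concurrently from async code
    return await asyncio.gather(*(aio_process(i) for i in range(3)))

results = asyncio.run(main())
print(results)  # ['Task 0', 'Task 1', 'Task 2']
```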
## Using AnyIO in Async Tasks
Flyte async tasks support `anyio` for structured concurrency as an alternative to `asyncio.gather()`.
```python
import anyio
import aioresult
import flyte
env = flyte.TaskEnvironment(
"anyio_example",
image=flyte.Image.from_debian_base().with_pip_packages("anyio", "aioresult"),
)
@env.task
async def process_item(x: int) -> int:
return x * 2
@env.task
async def batch_process(items: list[int]) -> list[int]:
captured_results = []
async with anyio.create_task_group() as tg:
# Start multiple tasks concurrently
for item in items:
captured_results.append(
aioresult.ResultCapture.start_soon(tg, process_item, item)
)
# Extract results
return [r.result() for r in captured_results]
```
**Note**: You can use anyio's task groups, timeouts, and other structured concurrency primitives within Flyte async tasks.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/unit-testing ===
Unit testing is essential for ensuring your Flyte tasks work correctly. Flyte 2.0 provides flexible testing approaches that allow you to test both your business logic and Flyte-specific features like type transformations and caching.
## Understanding Task Invocation
When working with functions decorated with `@env.task`, there are two ways to invoke them, each with different behavior:
### Direct Function Invocation
When you call a task directly like a regular Python function:
```python
result = my_task(x=10, y=20)
```
**Flyte features are NOT invoked**, including:
- Type transformations and serialization
- Caching
- Data validation
This behaves exactly like calling a regular Python function, making it ideal for testing your business logic.
### Using `flyte.run()`
When you invoke a task using `flyte.run()`:
```python
run = flyte.run(my_task, x=10, y=20)
result = run.outputs()
```
**Flyte features ARE invoked**, including:
- Type transformations and serialization
- Data validation
- Type checking (raises `flyte.errors` if types are not supported or restricted)
This allows you to test Flyte-specific behavior like serialization and caching.
## Testing Business Logic
For most unit tests, you want to verify your business logic works correctly. Use **direct function invocation** for this:
```python
import flyte
env = flyte.TaskEnvironment("my_env")
@env.task
def add(a: int, b: int) -> int:
return a + b
def test_add():
result = add(a=3, b=5)
assert result == 8
```
### Testing Async Tasks
Async tasks work the same way with direct invocation:
```python
import pytest
@env.task
async def subtract(a: int, b: int) -> int:
return a - b
@pytest.mark.asyncio
async def test_subtract():
result = await subtract(a=10, b=4)
assert result == 6
```
### Testing Nested Tasks
When tasks call other tasks, direct invocation continues to work without any Flyte overhead:
```python
@env.task
def nested(a: int, b: int) -> int:
return add(a, b) # Calls the add task directly
def test_nested():
result = nested(3, 5)
assert result == 8
```
## Testing Type Transformations and Serialization
When you need to test how Flyte handles data types, serialization, or caching, use `flyte.run()`:
```python
@pytest.mark.asyncio
async def test_add_with_flyte_run():
run = flyte.run(add, 3, 5)
assert run.outputs() == 8
```
### Testing Type Restrictions
Some types may not be supported or may be restricted. Use `flyte.run()` to test that these restrictions are enforced:
```python
from typing import Tuple
import flyte.errors
@env.task
def not_supported_types(x: Tuple[str, str]) -> str:
return x[0]
@pytest.mark.asyncio
async def test_not_supported_types():
# Direct invocation works fine
result = not_supported_types(x=("a", "b"))
assert result == "a"
# flyte.run enforces type restrictions
with pytest.raises(flyte.errors.RestrictedTypeError):
flyte.run(not_supported_types, x=("a", "b"))
```
### Testing Nested Tasks with Serialization
You can also test nested task execution with Flyte's full machinery:
```python
@pytest.mark.asyncio
async def test_nested_with_run():
run = flyte.run(nested, 3, 5)
assert run.outputs() == 8
```
## Testing Traced Functions
Functions decorated with `@flyte.trace` can be tested similarly to tasks:
```python
@flyte.trace
async def traced_multiply(a: int, b: int) -> int:
return a * b
@pytest.mark.asyncio
async def test_traced_multiply():
result = await traced_multiply(a=6, b=7)
assert result == 42
```
## Best Practices
1. **Test logic with direct invocation**: For most unit tests, call tasks directly to test your business logic without Flyte overhead.
2. **Test serialization with `flyte.run()`**: Use `flyte.run()` when you need to verify:
- Type transformations work correctly
- Data serialization/deserialization
- Caching behavior
- Type restrictions are enforced
3. **Use standard testing frameworks**: Flyte tasks work with pytest, unittest, and other Python testing frameworks.
4. **Test async tasks properly**: Use `@pytest.mark.asyncio` for async tasks and await their results.
5. **Mock external dependencies**: Use standard Python mocking techniques for external services, databases, etc.
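Because direct invocation behaves like a plain Python call, standard `unittest.mock` patching works unchanged on task functions. A sketch with a hypothetical external dependency (the endpoint URL and function names are illustrative, and the `@env.task` decorator is omitted so the example stays self-contained):

```python
import io
import json
import urllib.request
from unittest import mock

def fetch_price(symbol: str) -> float:
    """External dependency: calls a (hypothetical) pricing service."""
    with urllib.request.urlopen(f"https://api.example.com/price/{symbol}") as resp:
        return json.load(resp)["price"]

def price_with_tax(symbol: str, rate: float = 0.25) -> float:
    """Business logic under test."""
    return fetch_price(symbol) * (1 + rate)

def test_price_with_tax():
    # Replace the network call with a canned response
    fake_response = io.BytesIO(b'{"price": 100.0}')
    with mock.patch("urllib.request.urlopen") as fake_urlopen:
        fake_urlopen.return_value.__enter__.return_value = fake_response
        assert price_with_tax("ACME") == 125.0

test_price_with_tax()
```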
## Quick Reference
| Test Scenario | Method | Example |
|--------------|--------|---------|
| Business logic (sync) | Direct call | `result = task(x=10)` |
| Business logic (async) | Direct await | `result = await task(x=10)` |
| Type transformations | `flyte.run()` | `r = flyte.run(task, x=10)` |
| Data serialization | `flyte.run()` | `r = flyte.run(task, x=10)` |
| Caching behavior | `flyte.run()` | `r = flyte.run(task, x=10)` |
| Type restrictions | `flyte.run()` + pytest.raises | `pytest.raises(flyte.errors.RestrictedTypeError)` |
## Example Test Suite
Here's a complete example showing different testing approaches:
```python
import pytest
import flyte
import flyte.errors
env = flyte.TaskEnvironment("test_env")
@env.task
def add(a: int, b: int) -> int:
return a + b
@env.task
async def subtract(a: int, b: int) -> int:
return a - b
# Test business logic directly
def test_add_logic():
result = add(a=3, b=5)
assert result == 8
@pytest.mark.asyncio
async def test_subtract_logic():
result = await subtract(a=10, b=4)
assert result == 6
# Test with Flyte serialization
@pytest.mark.asyncio
async def test_add_serialization():
run = flyte.run(add, 3, 5)
assert run.outputs() == 8
@pytest.mark.asyncio
async def test_subtract_serialization():
run = flyte.run(subtract, a=10, b=4)
assert run.outputs() == 6
```
## Future Improvements
The Flyte SDK team is actively working on improvements for advanced unit testing scenarios, particularly around initialization and setup for complex test cases. Additional utilities and patterns may be introduced in future releases to make unit testing even more streamlined.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment ===
# Run and deploy tasks
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
You have seen how to configure and build the tasks that compose your project.
Now you need to decide how to execute them on your Flyte backend.
Flyte offers two distinct approaches for getting your tasks onto the backend:
**Use `flyte run` when you're iterating and experimenting:**
- Quickly test changes during development
- Try different parameters or code modifications
- Debug issues without creating permanent artifacts
- Prototype new ideas rapidly
**Use `flyte deploy` when your project is ready to be formalized:**
- Freeze a stable version of your tasks for repeated use
- Share tasks with team members or across environments
- Move from experimentation to a more structured workflow
- Create a permanent reference point (not necessarily production-ready)
This section explains both approaches and when to use each one.
## Ephemeral deployment and immediate execution
The `flyte run` CLI command and the `flyte.run()` SDK function are used to **ephemerally deploy** and **immediately execute** a task on the backend in a single step.
The task can be re-run and its execution and outputs can be observed in the **Runs list** UI, but it is not permanently added to the **Tasks list** on the backend.
Let's say you have the following file called `greeting.py`:
```python
# greeting.py
import flyte
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
```
### Programmatic
You can run the task programmatically using the `flyte.run()` function:
```python
# greeting.py
import flyte
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
if __name__ == "__main__":
flyte.init_from_config()
result = flyte.run(greet, message="Good morning!")
print(f"Result: {result}")
```
Here we add a `__main__` block to the `greeting.py` file that initializes the Flyte SDK from the configuration file and then calls `flyte.run()` with the `greet` task and its argument.
Now you can run the `greet` task on the backend just by executing the `greeting.py` file locally as a script:
```bash
python greeting.py
```
### CLI
The general form of the command for running a task from a local file is:
```bash
flyte run <file> <task-name> --<param> <value> ...
```
So, to run the `greet` task defined in the `greeting.py` file, you would run:
```bash
flyte run greeting.py greet --message "Good morning!"
```
This command:
1. **Temporarily deploys** the task environment named `greeting_env` (held by the variable `env`) that contains the `greet` task.
2. **Executes** the `greet` function with argument `message` set to `"Good morning!"`. Note that `message` is the actual parameter name defined in the function signature.
3. **Returns** the execution results and displays them in the terminal.
For more details on how `flyte run` and `flyte.run()` work under the hood, see **Run and deploy tasks > How task run works**.
## Persistent deployment
The `flyte deploy` CLI command and the `flyte.deploy()` SDK function are used to **persistently deploy** a task environment (and all its contained tasks) to the backend.
The tasks within the deployed environment will appear in the **Tasks list** UI on the backend and can then be executed multiple times without needing to redeploy them.
### Programmatic
You can deploy programmatically using the `flyte.deploy()` function:
```python
# greeting.py
import flyte
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
if __name__ == "__main__":
flyte.init_from_config()
deployments = flyte.deploy(env)
print(deployments[0].summary_repr())
```
Now you can deploy the `greeting_env` task environment (and therefore the `greet()` task) just by executing the `greeting.py` file locally as a script.
```bash
python greeting.py
```
### CLI
The general form of the command for deploying a task environment from a local file is:
```bash
flyte deploy <file> <env-variable>
```
So, using the same `greeting.py` file as before, you can deploy the `greeting_env` task environment like this:
```bash
flyte deploy greeting.py env
```
This command deploys the task environment *assigned to the variable `env`* in the `greeting.py` file, which is the `TaskEnvironment` named `greeting_env`.
Notice that you must specify the *variable* to which the `TaskEnvironment` is assigned (`env` in this case), not the name of the environment itself (`greeting_env`).
Deploying a task environment deploys all tasks defined within it. Here, that means all functions decorated with `@env.task`.
In this case there is just one: `greet()`.
For more details on how `flyte deploy` and `flyte.deploy()` work under the hood, see **Run and deploy tasks > How task deployment works**.
## Running already deployed tasks
If you have already deployed your task environment, you can run its tasks without redeploying by using the `flyte run` CLI command or the `flyte.run()` SDK function in a slightly different way. Alternatively, you can always initiate execution of a deployed task from the UI.
### Programmatic
You can run already-deployed tasks programmatically using the `flyte.run()` function.
For example, to run the previously deployed `greet` task from the `greeting_env` environment:
```python
# greeting.py
import flyte
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
if __name__ == "__main__":
flyte.init_from_config()
flyte.deploy(env)
task = flyte.remote.Task.get("greeting_env.greet", auto_version="latest")
result = flyte.run(task, message="Good morning!")
print(f"Result: {result}")
```
When you execute this script locally, it will:
- Deploy the `greeting_env` task environment as before.
- Retrieve the already-deployed `greet` task using `flyte.remote.Task.get()`, specifying its full task reference as a string: `"greeting_env.greet"`.
- Call `flyte.run()` with the retrieved task and its argument.
For more details on how running already-deployed tasks works, see **Run and deploy tasks > How task run works > Running deployed tasks**.
### CLI
To run a permanently deployed task using the `flyte run` CLI command, use the special `deployed-task` keyword followed by the task reference in the format `{environment_name}.{task_name}`. For example, to run the previously deployed `greet` task from the `greeting_env` environment:
```bash
flyte run deployed-task greeting_env.greet --message "World"
```
Notice that once the task environment is deployed, you refer to it by its name (`greeting_env`), not by the variable name to which it was assigned in source code (`env`).
The task environment name plus the task name (`greet`) are combined with a dot (`.`) to form the full task reference: `greeting_env.greet`.
The special `deployed-task` keyword tells the CLI that you are referring to a task that has already been deployed. In effect, it replaces the file path argument used for ephemeral runs.
When executed, this command will run the already-deployed `greet` task with argument `message` set to `"World"`. You will see the result printed in the terminal. You can also, of course, observe the execution in the **Runs list** UI.
To execute a deployed task in a different project or domain than your configured defaults, use `--run-project` and `--run-domain`:
```bash
flyte run --run-project prod-project --run-domain production deployed-task greeting_env.greet --message "World"
```
For all `flyte run` options, see **Run and deploy tasks > Run command options**.
## Configuring runs with `flyte.with_runcontext()`
Both `flyte run` and `flyte.run()` accept a range of invocation-time parameters that control where the run executes, where outputs are stored, caching behavior, and more.
Programmatically, these are set with `flyte.with_runcontext()` before calling `.run()`.
Inside a running task, `flyte.ctx()` provides read access to the same context.
For the full parameter reference, see **Run and deploy tasks > Run context**.
## Subpages
- **Run and deploy tasks > How task run works**
- **Run and deploy tasks > Interact with runs and actions**
- **Run and deploy tasks > Work with local data**
- **Run and deploy tasks > Run command options**
- **Run and deploy tasks > How task deployment works**
- **Run and deploy tasks > Deploy command options**
- **Run and deploy tasks > Code packaging for remote execution**
- **Run and deploy tasks > Deployment patterns**
- **Run and deploy tasks > Run context**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/how-task-run-works ===
# How task run works
The `flyte run` command and `flyte.run()` SDK function support three primary execution modes:
1. **Ephemeral deployment + run**: Automatically prepare task environments ephemerally and execute tasks (development shortcut)
2. **Run deployed task**: Execute permanently deployed tasks without redeployment
3. **Local execution**: Run tasks on your local machine for development and testing
Additionally, you can run deployed tasks through the Flyte/Union UI for interactive execution and monitoring.
## Ephemeral deployment + run: The development shortcut
The most common development pattern combines ephemeral task preparation and execution in a single command, automatically handling the temporary deployment process when needed.
### Programmatic
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def my_task(name: str) -> str:
return f"Hello, {name}!"
if __name__ == "__main__":
flyte.init_from_config()
# Deploy and run in one step
result = flyte.run(my_task, name="World")
print(f"Result: {result}")
print(f"Execution URL: {result.url}")
```
### CLI
```bash
flyte run my_example.py my_task --name "World"
```
With explicit project and domain:
```bash
flyte run --project my-project --domain development my_example.py my_task --name "World"
```
With deployment options:
```bash
flyte run --version v1.0.0 --copy-style all my_example.py my_task --name "World"
```
**How it works:**
1. **Environment discovery**: Flyte loads the specified Python file and identifies task environments
2. **Ephemeral preparation**: Temporarily prepares the task environment for execution (similar to deployment but not persistent)
3. **Task execution**: Immediately runs the specified task with provided arguments in the ephemeral environment
4. **Result return**: Returns execution results and monitoring URL
5. **Cleanup**: The ephemeral environment is not stored permanently in the backend
**Benefits of ephemeral deployment + run:**
- **Development efficiency**: No separate permanent deployment step required
- **Always current**: Uses your latest code changes without polluting the backend
- **Clean development**: Ephemeral environments don't clutter your task registry
- **Integrated workflow**: Single command for complete development cycle
## Running deployed tasks
For production workflows or when you want to use stable deployed versions, you can run tasks that have been **permanently deployed** with `flyte deploy` without triggering any deployment process.
### Programmatic
```python
import flyte
flyte.init_from_config()
# Method 1: Using remote task reference
deployed_task = flyte.remote.Task.get("my_env.my_task", version="v1.0.0")
result = flyte.run(deployed_task, name="World")
# Method 2: Get latest version
deployed_task = flyte.remote.Task.get("my_env.my_task", auto_version="latest")
result = flyte.run(deployed_task, name="World")
```
### CLI
```bash
flyte run deployed-task my_env.my_task --name "World"
```
With a specific project and domain:
```bash
flyte run --project prod --domain production deployed-task my_env.my_task --batch_size 1000
```
**Task reference format:** `{environment_name}.{task_name}`
- `environment_name`: The `name` property of your `TaskEnvironment`
- `task_name`: The function name of your task
>[!NOTE]
> When you deploy a task environment with `flyte deploy`, you specify the `TaskEnvironment` by the variable to which it is assigned.
> Once deployed, you refer to it by its `name` property.
**Benefits of running deployed tasks:**
- **Performance**: No deployment overhead, faster execution startup
- **Stability**: Uses tested, stable versions of your code
- **Production safety**: Isolated from local development changes
- **Version control**: Explicit control over which code version runs
## Local execution
For development, debugging, and testing, you can run tasks locally on your machine without any backend interaction.
### Programmatic
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def my_task(name: str) -> str:
return f"Hello, {name}!"
# Method 1: No client configured (defaults to local)
result = flyte.run(my_task, name="World")
# Method 2: Explicit local mode
flyte.init_from_config() # Client configured
result = flyte.with_runcontext(mode="local").run(my_task, name="World")
```
### CLI
```bash
flyte run --local my_example.py my_task --name "World"
```
With development data:
```bash
flyte run --local data_pipeline.py process_data --input_path "/local/data" --debug true
```
**Benefits of local execution:**
- **Rapid development**: Instant feedback without network latency
- **Debugging**: Full access to local debugging tools
- **Offline development**: Works without backend connectivity
- **Resource efficiency**: Uses local compute resources
## Running tasks through the Union UI
If you are running your Flyte code on a Union backend, the UI provides an interactive way to run deployed tasks with form-based input and real-time monitoring.
### Accessing task execution in the Union UI
1. **Navigate to tasks**: Go to your project, then your domain, then the **Tasks** section
2. **Select task**: Choose the task environment and specific task
3. **Launch execution**: Click "Launch" to open the execution form
4. **Provide inputs**: Fill in task parameters through the web interface
5. **Monitor progress**: Watch real-time execution progress and logs
**UI execution benefits:**
- **User-friendly**: No command-line expertise required
- **Visual monitoring**: Real-time progress visualization
- **Input validation**: Built-in parameter validation and type checking
- **Execution history**: Easy access to previous runs and results
- **Sharing**: Shareable execution URLs for collaboration
Here is a short video demonstrating task execution through the Union UI:
[Watch on YouTube](https://www.youtube.com/watch?v=8jbau9yGoDg)
## Execution flow and architecture
### Fast registration architecture
Flyte v2 uses "fast registration" to enable rapid development cycles:
#### How it works
1. **Container images** contain the runtime environment and dependencies
2. **Code bundles** contain your Python source code (stored separately)
3. **At runtime**: Code bundles are downloaded and injected into running containers
#### Benefits
- **Rapid iteration**: Update code without rebuilding images
- **Resource efficiency**: Share images across multiple deployments
- **Version flexibility**: Run different code versions with same base image
- **Caching optimization**: Separate caching for images vs. code
#### When code gets injected
At task execution time, the fast registration process follows these steps:
1. **Container starts** with the base image containing runtime environment and dependencies
2. **Code bundle download**: The Flyte agent downloads your Python code bundle from storage
3. **Code extraction**: The code bundle is extracted and mounted into the running container
4. **Task execution**: Your task function executes with the injected code
### Ephemeral preparation logic
When using ephemeral deploy + run mode, Flyte determines whether temporary preparation is needed:
```mermaid
graph TD
A[flyte run command] --> B{Need preparation?}
B -->|Yes| C[Ephemeral preparation]
B -->|No| D[Use cached preparation]
C --> E[Execute task]
D --> E
E --> F[Cleanup ephemeral environment]
```
### Execution modes comparison
| Mode | Deployment | Performance | Use Case | Code Version |
|------|------------|-------------|-----------|--------------|
| Ephemeral Deploy + Run | Ephemeral (temporary) | Medium | Development, testing | Latest local |
| Run Deployed | None (uses permanent deployment) | Fast | Production, stable runs | Deployed version |
| Local | None | Variable | Development, debugging | Local |
| UI | None | Fast | Interactive, collaboration | Deployed version |
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/interacting-with-runs ===
# Interact with runs and actions
When a task is launched, the resulting execution is called a **run**.
Because tasks typically call other tasks, a run will almost always involve multiple sub-task executions. Each such execution is called an **action**.
Through the Flyte SDK and CLI, you can interact with the run and its actions to monitor progress, retrieve results, and access data. This section explains how to work with runs and actions programmatically and through the CLI.
## Understanding runs and actions
Runs are not declared explicitly in the code of the entry point task.
Instead, they are simply a result of the task being invoked in a specific way:
* A user invoking the `flyte run` CLI command
* A user launching the task via the UI
* Other code calling `flyte.run()`
* A [trigger](../task-configuration/triggers) firing
When a task is invoked in one of these ways, it creates a run to represent the execution of that task and all its nested tasks, considered together.
Each task execution within that run is represented by an **action**.
The entry point task execution is represented by the main action (usually called `a0`), and then every nested call of one task from another creates an additional action.
```mermaid
graph TD
A[Run] --> B[Action a0 - Main task]
B --> C[Action a1 - Nested task]
B --> D[Action a2 - Nested task]
D --> E[Action a3 - Deeply nested task]
```
Because what constitutes a run depends only on how a task is invoked, the same task can execute as a deeply nested action in one run and the main action in another run.
Unlike Flyte 1, there is no explicit `@workflow` construct in Flyte 2; instead, "workflows" are defined implicitly by the structure of task composition and the entry point chosen at runtime.
> [!NOTE]
> Despite there being no explicit `@workflow` decorator, you'll often see the assemblage of tasks referred to as a "workflow" in documentation and discussions. The top-most task in a run is sometimes referred to as the "parent", "driver", or "entry point" task of the "workflow".
> In these docs we will sometimes use "workflow" informally to refer to the collection of tasks (considered statically) involved in a run.
### Key concepts
- **Attempts**: Each action can have multiple attempts due to retries. Retries occur for two reasons:
- User-configured retries for handling transient failures
- Automatic system retries for infrastructure issues
- **Phases**: Both runs and actions progress through phases (e.g., QUEUED, RUNNING, SUCCEEDED, FAILED) until reaching a terminal state
- **Durability**: Flyte is a durable execution engine, so every input, output, failure, and attempt is recorded for each action. All data is persisted, allowing you to retrieve information about runs and actions even after completion
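As a plain-Python illustration of the phase lifecycle (a hypothetical helper, not part of the SDK; in real code `run.wait()` blocks until a terminal state for you):

```python
import time

# Terminal phases as named in the docs above; the SDK exposes phases via
# flyte.models.ActionPhase. This sketch is illustrative only.
TERMINAL_PHASES = {"SUCCEEDED", "FAILED"}

def wait_for_terminal(get_phase, poll_seconds=0.0, max_polls=100):
    """Poll a phase-returning callable until a terminal phase is reached."""
    for _ in range(max_polls):
        phase = get_phase()  # stands in for checking run.action.phase
        if phase in TERMINAL_PHASES:
            return phase
        time.sleep(poll_seconds)
    raise TimeoutError("action did not reach a terminal phase")

# Simulate an action progressing through its phases
phases = iter(["QUEUED", "RUNNING", "RUNNING", "SUCCEEDED"])
final = wait_for_terminal(lambda: next(phases))
```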
## Working with runs
Runs are created when you execute tasks using `flyte run` or `flyte.run()`. For details on running tasks, see [how task run works](./how-task-run-works). To learn about running previously deployed remote tasks, see [remote tasks](../task-programming/remote-tasks).
### Retrieving a run
### Programmatic
Use `flyte.remote.Run.get()` to retrieve information about a run:
```python
import flyte
flyte.init_from_config()
# Get a run by name
run = flyte.remote.Run.get("my_run_name")
# Access basic information
print(run.url) # UI URL for the run
print(run.action.phase) # Phase of the main action
```
### CLI
Get a specific run:
```bash
flyte get run my_run_name
```
List all runs:
```bash
flyte get run
```
Use `--project` and `--domain` to scope results to a specific [project-domain pair](../projects-and-domains).
For all available options, see the [CLI reference](../../api-reference/flyte-cli#flyte-get-run).
### Watching run progress
Monitor a run as it progresses through phases:
```python
# Wait for run to complete
run = flyte.run(my_task, input_data="test")
run.wait() # Blocks until terminal state
# Check if done
if run.action.done():
    print("Run completed!")
```
### Getting detailed run information
Use `flyte.remote.RunDetails` for comprehensive information including nested actions and metadata:
```python
run_details = flyte.remote.RunDetails.get(name="my_run_name")
# Access detailed information
print(run_details.pb2) # Full protobuf representation
```
## Working with actions
Actions represent individual task executions within a run. Each action has a unique identifier within its parent run.
### Retrieving an action
### Programmatic
```python
# Get a specific action by run name and action name
action = flyte.remote.Action.get(
    run_name="my_run_name",
    name="a0",  # Main action
)
# Access action information
print(action.phase) # Current phase
print(action.task_name) # Task being executed
print(action.start_time) # Execution start time
```
### CLI
Get a specific action:
```bash
flyte get action my_run_name a0
```
List all actions for a run:
```bash
flyte get action my_run_name
```
For all available options, see the [CLI reference](../../api-reference/flyte-cli#flyte-get-action).
### Nested actions
Deeply nested actions are uniquely identified by their path under the run:
```python
# Get a nested action
nested_action = flyte.remote.Action.get(
    run_name="my_run_name",
    name="a1",  # Nested action identifier
)
```
### Getting detailed action information
Use `flyte.remote.ActionDetails` for comprehensive action information:
```python
action_details = flyte.remote.ActionDetails.get(
    run_name="my_run_name",
    name="a0",
)
# Access detailed information
print(action_details.pb2) # Full protobuf representation
```
## Retrieving inputs and outputs
### Programmatic
Both `Run` and `Action` objects provide methods to retrieve inputs and outputs:
```python
run = flyte.remote.Run.get("my_run_name")
# Get inputs - returns ActionInputs (dict-like)
inputs = run.inputs()
print(inputs) # {"param_name": value, ...}
# Get outputs - returns tuple
outputs = run.outputs()
print(outputs) # (result1, result2, ...)
# Single output
single_output = outputs[0]
# No outputs are represented as (None,)
```
**Important notes:**
- **Inputs**: Returned as `flyte.remote.ActionInputs`, a dictionary with parameter names as keys and values as the actual data passed in
- **Outputs**: Always returned as `flyte.remote.ActionOutputs` tuple, even for single outputs or no outputs
- **No outputs**: Represented as `(None,)`
- **Availability**: Outputs are only available if the action completed successfully
- **Type safety**: Flyte's rich type system converts data to an intermediate representation, allowing retrieval even without the original dependencies installed
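As a plain-Python illustration of the tuple conventions above (`first_output` is a hypothetical helper, not part of the SDK):

```python
def first_output(outputs):
    """Return the single result of an action, or None if it produced no outputs."""
    # "No outputs" is represented as (None,), so outputs[0] is None in that case
    return outputs[0]

single = first_output(("Processed: test",))  # one output
empty = first_output((None,))                # no outputs
a, b = ("result1", "result2")                # multiple outputs unpack like any tuple
```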
### CLI
Get inputs and outputs for a run:
```bash
flyte get io my_run_name
```
Get inputs and outputs for a specific action:
```bash
flyte get io my_run_name a1
```
For all available options, see the [CLI reference](../../api-reference/flyte-cli#flyte-get-io).
### Handling failures
If an action fails, outputs are not available, but you can retrieve error information:
```python
action = flyte.remote.Action.get(run_name="my_run_name", name="a0")
if action.phase == flyte.models.ActionPhase.FAILED:
    # Outputs will raise an error
    try:
        outputs = action.outputs()
    except RuntimeError:
        print("Action failed, outputs not available")

    # Get error details instead
    action_details = flyte.remote.ActionDetails.get(
        run_name="my_run_name",
        name="a0",
    )
    print(action_details.pb2.error_info)
```
## Understanding data storage
Flyte handles different types of data differently, as explained in [data flow](../run-scaling/data-flow):
- **Parameterized data** (primitives, small objects): Returned directly in inputs/outputs
- **Large data** (files, directories, DataFrames, models): Stored in cloud storage (S3, GCS, Azure Blob Storage)
When you retrieve outputs containing large data, Flyte returns references rather than the data itself. To download the underlying data, you need the appropriate cloud storage permissions and configuration.
## Accessing large data from cloud storage
To download and work with files, directories, and DataFrames stored in cloud object storage, you must configure storage access with appropriate credentials.
### S3 storage access
To access data stored in Amazon S3:
**1. Set environment variables:**
```bash
export FLYTE_AWS_ACCESS_KEY_ID="your-access-key-id"
export FLYTE_AWS_SECRET_ACCESS_KEY="your-secret-access-key"
```
These are standard AWS credential environment variables that Flyte recognizes.
**2. Initialize Flyte with S3 storage configuration:**
```python
import flyte
import flyte.storage
# Auto-configure from environment variables
flyte.init_from_config(
    storage=flyte.storage.S3.auto(region="us-east-2")
)

# Or provide credentials explicitly
flyte.init_from_config(
    storage=flyte.storage.S3(
        access_key_id="your-access-key-id",
        secret_access_key="your-secret-access-key",
        region="us-east-2",
    )
)
```
**3. Access data from outputs:**
```python
import pandas as pd

run = flyte.remote.Run.get("my_run_name")
outputs = run.outputs()
# Outputs containing files, DataFrames, etc. can now be downloaded
dataframe = outputs[0]
# `open(...).all()` is awaitable, so call it from inside an async function
df = await dataframe.open(pd.DataFrame).all()
```
### GCS storage access
To access data stored in Google Cloud Storage:
**1. Set environment variables:**
```bash
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service-account-key.json"
```
This is the standard Google Cloud authentication method using service account credentials.
**2. Initialize Flyte with GCS storage configuration:**
```python
import flyte
import flyte.storage
# Auto-configure from environment
flyte.init_from_config(
    storage=flyte.storage.GCS.auto()
)

# Or configure explicitly
flyte.init_from_config(
    storage=flyte.storage.GCS()
)
```
**3. Access data from outputs:**
```python
run = flyte.remote.Run.get("my_run_name")
outputs = run.outputs()
# Download data as needed
file_output = outputs[0]
# Work with file output
```
### Azure Blob Storage access
To access data stored in Azure Blob Storage (ABFS):
**1. Set environment variables:**
For storage account key authentication:
```bash
export AZURE_STORAGE_ACCOUNT_NAME="your-storage-account"
export AZURE_STORAGE_ACCOUNT_KEY="your-account-key"
```
For service principal authentication:
```bash
export AZURE_TENANT_ID="your-tenant-id"
export AZURE_CLIENT_ID="your-client-id"
export AZURE_CLIENT_SECRET="your-client-secret"
export AZURE_STORAGE_ACCOUNT_NAME="your-storage-account"
```
**2. Initialize Flyte with Azure storage configuration:**
```python
import flyte
import flyte.storage
# Auto-configure from environment variables
flyte.init_from_config(
    storage=flyte.storage.ABFS.auto()
)

# Or provide credentials explicitly
flyte.init_from_config(
    storage=flyte.storage.ABFS(
        account_name="your-storage-account",
        account_key="your-account-key",
    )
)

# Or use service principal
flyte.init_from_config(
    storage=flyte.storage.ABFS(
        account_name="your-storage-account",
        tenant_id="your-tenant-id",
        client_id="your-client-id",
        client_secret="your-client-secret",
    )
)
```
**3. Access data from outputs:**
```python
run = flyte.remote.Run.get("my_run_name")
outputs = run.outputs()
# Download data as needed
directory_output = outputs[0]
# Work with directory output
```
## Complete example
Here's a complete example showing how to launch a run and interact with it:
```python
import flyte
import flyte.storage
# Initialize with storage access
flyte.init_from_config(
    storage=flyte.storage.S3.auto(region="us-east-2")
)
# Define and run a task
env = flyte.TaskEnvironment(name="data_processing")
@env.task
async def process_data(input_value: str) -> str:
    return f"Processed: {input_value}"
# Launch the run
run = flyte.run(process_data, input_value="test_data")
# Monitor progress
print(f"Run URL: {run.url}")
run.wait()
# Check status
if run.action.done():
    print(f"Run completed with phase: {run.action.phase}")

# Get inputs and outputs
inputs = run.inputs()
print(f"Inputs: {inputs}")
outputs = run.outputs()
print(f"Outputs: {outputs}")

# Access the result
result = outputs[0]
print(f"Result: {result}")
```
## API reference
### Key classes
- `flyte.remote.Run` - Represents a run with basic information
- `flyte.remote.RunDetails` - Detailed run information including all actions
- `flyte.remote.Action` - Represents an action with basic information
- `flyte.remote.ActionDetails` - Detailed action information including error details
- `flyte.remote.ActionInputs` - Dictionary-like object containing action inputs
- `flyte.remote.ActionOutputs` - Tuple containing action outputs
### CLI commands
For complete CLI documentation and all available options, see the [Flyte CLI reference](../../api-reference/flyte-cli):
- [`flyte get run`](../../api-reference/flyte-cli#flyte-get-run) - Get run information
- [`flyte get action`](../../api-reference/flyte-cli#flyte-get-action) - Get action information
- [`flyte get io`](../../api-reference/flyte-cli#flyte-get-io) - Get inputs and outputs
- [`flyte get logs`](../../api-reference/flyte-cli#flyte-get-logs) - Get action logs
### Storage configuration
- `flyte.storage.S3` - Amazon S3 configuration
- `flyte.storage.GCS` - Google Cloud Storage configuration
- `flyte.storage.ABFS` - Azure Blob Storage configuration
For more details on data flow and storage, see [data flow](../run-scaling/data-flow).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/work-with-local-data ===
# Work with local data
When running Flyte tasks that take inputs like DataFrames, files, or directories, data is passed between actions through the configured blob store. For details on how data flows through your workflows, see [data flow](../run-scaling/data-flow).
Flyte provides several built-in types for handling data:
- `flyte.io.DataFrame` for tabular data
- `flyte.io.File` for individual files
- `flyte.io.Dir` for directories
You can also create custom type extensions for specialized data types. See [custom types](../task-programming/handling-custom-types) for details.
## Local execution
One of the most powerful features of Flyte is the ability to work with data entirely locally, without creating a remote run. When you run tasks in local mode, all inputs, outputs, and intermediate data stay on your local machine.
```python
import flyte
env = flyte.TaskEnvironment(name="local_data")
@env.task
async def process_data(data: str) -> str:
    return f"Processed: {data}"
# Run locally - no remote storage needed
run = flyte.with_runcontext(mode="local").run(process_data, data="test")
run.wait()
print(run.outputs()[0])
```
For more details on local execution, see [how task run works](./how-task-run-works#local-execution).
## Uploading local data to remote runs
When you want to send local data to a remote task, you need to upload it first. Flyte provides a secure data uploading system that handles this automatically. The same system used for [code bundling](./packaging) can upload files, DataFrames, and directories.
To upload local data, call the `from_local_sync()` method on the Flyte core type corresponding to your data.
### Uploading DataFrames
Use `flyte.io.DataFrame.from_local_sync()` to upload a local DataFrame:
```python
import pandas as pd
import flyte
import flyte.io
img = flyte.Image.from_debian_base()
img = img.with_pip_packages("pandas", "pyarrow")
env = flyte.TaskEnvironment(
    "dataframe_usage",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)

@env.task
async def process_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Process a DataFrame and return the result."""
    df["processed"] = True
    return df

if __name__ == "__main__":
    flyte.init_from_config()

    # Create a local pandas DataFrame
    local_df = pd.DataFrame({
        "name": ["Alice", "Bob", "Charlie"],
        "value": [10, 20, 30],
    })

    # Upload the local DataFrame for remote execution
    uploaded_df = flyte.io.DataFrame.from_local_sync(local_df)

    # Pass to a remote task
    run = flyte.run(process_dataframe, df=uploaded_df)
    print(f"Run URL: {run.url}")
    run.wait()
    print(run.outputs()[0])
```
### Uploading files
Use `flyte.io.File.from_local_sync()` to upload a local file:
```python
import tempfile
import flyte
from flyte.io import File
env = flyte.TaskEnvironment(name="file-local")
@env.task
async def process_file(file: File) -> str:
    """Read and process a file."""
    async with file.open("rb") as f:
        content = bytes(await f.read())
    return content.decode("utf-8")

if __name__ == "__main__":
    flyte.init_from_config()

    # Create a temporary local file
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as temp:
        temp.write("Hello, Flyte!")
        temp_path = temp.name

    # Upload the local file for remote execution
    file = File.from_local_sync(temp_path)

    # Pass to a remote task
    run = flyte.run(process_file, file=file)
    print(f"Run URL: {run.url}")
    run.wait()
    print(run.outputs()[0])
```
### Uploading directories
Use `flyte.io.Dir.from_local_sync()` to upload a local directory:
```python
import os
import tempfile
import flyte
from flyte.io import Dir
env = flyte.TaskEnvironment(name="dir-local")
@env.task
async def process_dir(dir: Dir) -> dict[str, str]:
    """Process a directory and return file contents."""
    file_contents = {}
    async for file in dir.walk(recursive=False):
        if file.name.endswith(".py"):
            async with file.open("rb") as f:
                content = bytes(await f.read())
            file_contents[file.name] = content.decode("utf-8")[:100]
    return file_contents

if __name__ == "__main__":
    flyte.init_from_config()

    # Create a temporary directory with test files
    with tempfile.TemporaryDirectory() as temp_dir:
        for i in range(3):
            with open(os.path.join(temp_dir, f"file{i}.py"), "w") as f:
                f.write(f"print('Hello from file {i}!')")

        # Upload the local directory for remote execution
        dir = Dir.from_local_sync(temp_dir)

        # Pass to a remote task
        run = flyte.run(process_dir, dir=dir)
        print(f"Run URL: {run.url}")
        run.wait()
        print(run.outputs()[0])
```
## Passing outputs between runs
If you're passing outputs from a previous run to a new run, no upload is needed. Flyte represents such data as references to storage locations, so passing them between runs works automatically:
```python
import flyte
flyte.init_from_config()
# Get outputs from a previous run
previous_run = flyte.remote.Run.get("my_previous_run")
previous_output = previous_run.outputs()[0] # Already a Flyte reference
# Pass directly to a new run - no upload needed
new_run = flyte.run(my_task, data=previous_output)
```
## Performance considerations
The `from_local_sync()` method uses HTTP to upload data. This is convenient but not the most performant option for large datasets.
**Best suited for:**
- Small to medium test datasets
- Development and debugging
- Quick prototyping
**For larger data uploads**, configure cloud storage access and use `flyte.storage` directly:
```python
import flyte
import flyte.storage
# Configure storage access
flyte.init_from_config(
    storage=flyte.storage.S3.auto(region="us-east-2")
)
```
For details on configuring storage access, see [interact with runs and actions](./interacting-with-runs#accessing-large-data-from-cloud-storage).
## Summary
| Scenario | Approach |
|----------|----------|
| Local development and testing | Use local execution mode |
| Small test data to remote tasks | Use `from_local_sync()` |
| Passing data between runs | Pass outputs directly (automatic) |
| Large datasets | Configure `flyte.storage` for direct cloud access |
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/run-command-options ===
# Run command options
The `flyte run` command provides the following options:
**`flyte run [OPTIONS] <file> <task> | deployed-task <task>`**
| Option | Short | Type | Default | Description |
|-----------------------------|-------|--------|---------------------------|--------------------------------------------------------|
| `--project` | `-p` | text | *from config* | Project to run tasks in |
| `--domain` | `-d` | text | *from config* | Domain to run tasks in |
| `--local` | | flag | `false` | Run the task locally |
| `--copy-style` | | choice | `loaded_modules` | Code bundling strategy: `loaded_modules`, `all`, or `none` |
| `--root-dir` | | path | *current dir* | Override source root directory |
| `--raw-data-path` | | text | | Override the output location for offloaded data types. |
| `--service-account` | | text | | Kubernetes service account. |
| `--name` | | text | | Name of the run. |
| `--follow` | `-f` | flag | `false` | Wait and watch logs for the parent action. |
| `--image` | | text | | Image to be used in the run (format: `name=uri`). |
| `--no-sync-local-sys-paths` | | flag | `false` | Disable synchronization of local sys.path entries. |
| `--run-project` | | text | *from config* | Execute deployed task in this project (`deployed-task` only). |
| `--run-domain` | | text | *from config* | Execute deployed task in this domain (`deployed-task` only). |
## `--project`, `--domain`
**`flyte run --domain <domain> --project <project> <file> <task> | deployed-task <task>`**
You can specify `--project` and `--domain` which will override any defaults defined in your `config.yaml`:
```bash
flyte run my_example.py my_task
```
Specify a target project and domain:
```bash
flyte run --project my-project --domain development my_example.py my_task
```
## `--run-project`, `--run-domain`
**`flyte run --run-project <project> --run-domain <domain> deployed-task <task>`**
When using the `deployed-task` subcommand, `--run-project` and `--run-domain` specify the [project-domain pair](../projects-and-domains) in which to *execute* the task. This lets you run a deployed task in a different project or domain than the one configured in your `config.yaml`:
```bash
flyte run --run-project prod-project --run-domain production deployed-task my_env.my_task
```
If not provided, these default to the `task.project` and `task.domain` values in your configuration file. These options only apply to the `deployed-task` subcommand and are ignored for file-based runs.
## `--local`
**`flyte run --local <file> <task>`**
The `--local` option runs tasks locally instead of submitting them to the remote Flyte backend:
```bash
flyte run --local my_example.py my_task --input "test_data"
```
Compare with remote execution:
```bash
flyte run my_example.py my_task --input "test_data"
```
### When to use local execution
- **Development and testing**: Quick iteration without deployment overhead
- **Debugging**: Full access to local debugging tools and environment
- **Resource constraints**: When remote resources are unavailable or expensive
- **Data locality**: When working with large local datasets
## `--copy-style`
**`flyte run --copy-style [loaded_modules|all|none] <file> <task>`**
The `--copy-style` option controls code bundling for remote execution.
This applies to the ephemeral preparation step of the `flyte run` command and works similarly to `flyte deploy`:
Smart bundling (default) includes only imported project modules:
```bash
flyte run --copy-style loaded_modules my_example.py my_task
```
Include all project files:
```bash
flyte run --copy-style all my_example.py my_task
```
No code bundling (task must be pre-deployed):
```bash
flyte run --copy-style none deployed_task my_deployed_task
```
### Copy style options
- **`loaded_modules` (default)**: Bundles only imported Python modules from your project
- **`all`**: Includes all files in the project directory
- **`none`**: No bundling; requires permanently deployed tasks
## `--root-dir`
**`flyte run --root-dir <dir> <file> <task>`**
Override the source directory for code bundling and import resolution:
Run from a monorepo root with a specific root directory:
```bash
flyte run --root-dir ./services/ml ./services/ml/my_example.py my_task
```
Handle cross-directory imports:
```bash
flyte run --root-dir .. my_example.py my_workflow
```
This applies to the ephemeral preparation step of the `flyte run` command.
It works identically to the `flyte deploy` command's `--root-dir` option.
## `--raw-data-path`
**`flyte run --raw-data-path <path> <file> <task>`**
Override the default output location for offloaded data types (large objects, DataFrames, etc.):
Use a custom S3 location for large outputs:
```bash
flyte run --raw-data-path s3://my-bucket/custom-path/ my_example.py process_large_data
```
Use a local directory for development:
```bash
flyte run --local --raw-data-path ./output/ my_example.py my_task
```
### Use cases
- **Custom storage locations**: Direct outputs to specific S3 buckets or paths
- **Cost optimization**: Use cheaper storage tiers for temporary data
- **Access control**: Ensure outputs go to locations with appropriate permissions
- **Local development**: Store large outputs locally when testing
## `--service-account`
**`flyte run --service-account <account> <file> <task>`**
Specify a Kubernetes service account for task execution:
```bash
flyte run --service-account ml-service-account my_example.py train_model
flyte run --service-account data-reader-sa my_example.py load_data
```
### Use cases
- **Cloud resource access**: Service accounts with permissions for S3, GCS, etc.
- **Security isolation**: Different service accounts for different workload types
- **Compliance requirements**: Enforcing specific identity and access policies
## `--name`
**`flyte run --name <name> <file> <task>`**
Provide a custom name for the execution run:
```bash
flyte run --name "daily-training-run-2024-12-02" my_example.py train_model
flyte run --name "experiment-lr-0.01-batch-32" my_example.py hyperparameter_sweep
```
### Benefits of custom names
- **Easy identification**: Find specific runs in the Flyte console
- **Experiment tracking**: Include key parameters or dates in names
- **Automation**: Programmatically generate meaningful names for scheduled runs
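As a sketch of the automation point above, a small helper can assemble a descriptive run name from a date and key parameters (the naming scheme and helper are hypothetical; the date is fixed here to mirror the example above):

```python
from datetime import date

def make_run_name(prefix: str, **params) -> str:
    """Build a descriptive run name like 'experiment-2024-12-02-lr-0.01'."""
    parts = [prefix, date(2024, 12, 2).isoformat()]
    for key, value in sorted(params.items()):  # stable ordering of parameters
        parts.append(f"{key}-{value}")
    return "-".join(parts)

name = make_run_name("experiment", lr=0.01, batch=32)
# The result can then be passed via `flyte run --name ...`
```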
## `--follow`
**`flyte run --follow <file> <task>`**
Wait and watch logs for the execution in real-time:
```bash
flyte run --follow my_example.py long_running_task
```
Combine with other options:
```bash
flyte run --follow --name "training-session" my_example.py train_model
```
### Behavior
- **Log streaming**: Real-time output from task execution
- **Blocking execution**: Command waits until task completes
- **Exit codes**: Returns appropriate exit code based on task success/failure
## `--image`
**`flyte run --image <name=uri> <file> <task>`**
Override container images during ephemeral preparation, same as the equivalent `flyte deploy` option:
Override a specific named image:
```bash
flyte run --image gpu=ghcr.io/org/gpu:v2.1 my_example.py gpu_task
```
Override the default image:
```bash
flyte run --image ghcr.io/org/custom:latest my_example.py my_task
```
Multiple image overrides:
```bash
flyte run \
--image base=ghcr.io/org/base:v1.0 \
--image gpu=ghcr.io/org/gpu:v2.0 \
my_example.py multi_env_workflow
```
### Image mapping formats
- **Named mapping**: `name=uri` overrides images created with `Image.from_ref_name("name")`
- **Default mapping**: `uri` overrides the default "auto" image
- **Multiple mappings**: Use multiple `--image` flags for different image references
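The mapping formats above can be sketched as a small parser (illustrative only; the CLI's own parsing may differ):

```python
def parse_image_flags(flags):
    """Parse repeated --image values into {name: uri} overrides.

    A bare URI (no '=') overrides the default "auto" image; 'name=uri'
    overrides the image registered under that name.
    """
    overrides = {}
    for flag in flags:
        if "=" in flag:
            name, uri = flag.split("=", 1)
        else:
            name, uri = "auto", flag
        overrides[name] = uri
    return overrides

overrides = parse_image_flags([
    "base=ghcr.io/org/base:v1.0",
    "gpu=ghcr.io/org/gpu:v2.0",
    "ghcr.io/org/custom:latest",
])
```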
## `--no-sync-local-sys-paths`
**`flyte run --no-sync-local-sys-paths <file> <task>`**
Disable synchronization of local `sys.path` entries to the remote execution environment during ephemeral preparation.
Identical to the `flyte deploy` command's `--no-sync-local-sys-paths` option:
```bash
flyte run --no-sync-local-sys-paths my_example.py my_task
```
This advanced option is useful for:
- **Container isolation**: Prevent local development paths from affecting remote execution
- **Custom environments**: When containers have pre-configured Python paths
- **Security**: Avoiding exposure of local directory structures
## Task argument passing
Arguments are passed directly as function parameters:
CLI: arguments passed as flags:
```bash
flyte run my_file.py my_task --name "World" --count 5 --debug true
```
SDK: arguments passed as function parameters:
```python
result = flyte.run(my_task, name="World", count=5, debug=True)
```
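As a simplified sketch of that mapping (the real CLI derives type coercion from your task's type hints; the helper below is hypothetical and hardcodes a few types):

```python
def coerce_cli_args(arg_strings, param_types):
    """Convert flag strings like ['--count', '5'] into typed kwargs."""
    kwargs = {}
    it = iter(arg_strings)
    for flag in it:
        name = flag.removeprefix("--")
        raw = next(it)                 # the value following the flag
        target = param_types[name]
        if target is bool:
            kwargs[name] = raw.lower() == "true"
        else:
            kwargs[name] = target(raw)
    return kwargs

kwargs = coerce_cli_args(
    ["--name", "World", "--count", "5", "--debug", "true"],
    {"name": str, "count": int, "debug": bool},
)
```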
## SDK options
The core `flyte run` functionality is also available programmatically through the `flyte.run()` function.
For SDK-level configuration of all run parameters (storage, caching, identity, logging, and more),
see [Run context](./run-context).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/how-task-deployment-works ===
# How task deployment works
In this section, we will take a deep dive into how the `flyte deploy` command and the `flyte.deploy()` SDK function work under the hood to deploy tasks to your Flyte backend.
When you perform a deployment, here's what happens:
## 1. Module loading and task environment discovery
In the first step, Flyte determines which files to load in order to search for task environments, based on the command line options provided:
### Single file (default)
```bash
flyte deploy my_example.py env
```
- The file `my_example.py` is executed.
- All declared `TaskEnvironment` objects in the file are instantiated, but only the one assigned to the variable `env` is selected for deployment.
### `--all` option
```bash
flyte deploy --all my_example.py
```
- The file `my_example.py` is executed.
- All declared `TaskEnvironment` objects in the file are instantiated and selected for deployment.
- No specific variable name is required.
### `--recursive` option
```bash
flyte deploy --recursive ./directory
```
- The directory is recursively traversed; every Python file is executed and all `TaskEnvironment` objects are instantiated.
- All `TaskEnvironment` objects across all files are selected for deployment.
## 2. Task analysis and serialization
- For every task environment selected for deployment, all of its tasks are identified.
- Task metadata is extracted: parameter types, return types, and resource requirements.
- Each task is serialized into a Flyte `TaskTemplate`.
- Dependency graphs between environments are built (see below).
## 3. Task environment dependency resolution
In many cases, a task in one environment may invoke a task in another environment, establishing a dependency between the two environments.
For example, if `env_a` has a task that calls a task in `env_b`, then `env_a` depends on `env_b`.
This means that when deploying `env_a`, `env_b` must also be deployed to ensure that all tasks can be executed correctly.
To handle this, `TaskEnvironment`s can declare dependencies on other `TaskEnvironment`s using the `depends_on` parameter.
During deployment, the system performs the following steps to resolve these dependencies:
1. Start with the specified environment(s)
2. Recursively discover all transitive dependencies
3. Include all dependencies in the deployment plan
4. Process dependencies depth-first to ensure correct order
```python
# Define environments with dependencies
prep_env = flyte.TaskEnvironment(name="preprocessing")
ml_env = flyte.TaskEnvironment(name="ml_training", depends_on=[prep_env])
viz_env = flyte.TaskEnvironment(name="visualization", depends_on=[ml_env])
# Deploy only viz_env - automatically includes ml_env and prep_env
deployment = flyte.deploy(viz_env, version="v2.0.0")
# Or deploy multiple environments explicitly
deployment = flyte.deploy(data_env, ml_env, viz_env, version="v2.0.0")
```
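The resolution steps above can be sketched as a depth-first traversal over `depends_on` links (illustrative only; plain dicts stand in for `TaskEnvironment` objects):

```python
def resolve_deployment_order(environments):
    """Return environment names in deploy order: dependencies before dependents."""
    ordered, seen = [], set()

    def visit(env):
        if env["name"] in seen:
            return
        seen.add(env["name"])
        for dep in env.get("depends_on", []):  # recurse into dependencies first
            visit(dep)
        ordered.append(env["name"])

    for env in environments:
        visit(env)
    return ordered

prep = {"name": "preprocessing"}
ml = {"name": "ml_training", "depends_on": [prep]}
viz = {"name": "visualization", "depends_on": [ml]}

# Deploying only viz still pulls in its transitive dependencies, in order
order = resolve_deployment_order([viz])
```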
For detailed information about working with multiple environments, see [Multiple Environments](../task-configuration/multiple-environments).
## 4. Code bundle creation and upload
Once the task environments and their dependencies are resolved, Flyte proceeds to package your code into a bundle based on the `copy_style` option:
### `--copy_style loaded_modules` (default)
This is the smart bundling approach that analyzes which Python modules were actually imported during the task environment discovery phase.
It examines the runtime module registry (`sys.modules`) and includes only those modules that meet specific criteria:
they must have source files located within your project directory (not in system locations like `site-packages`), and they must not be part of the Flyte SDK itself.
This selective approach results in smaller, faster-to-upload bundles that contain exactly the code needed to run your tasks, making it ideal for most development and production scenarios.
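A rough sketch of that selection logic, operating on a plain module-name-to-source-path mapping instead of the real `sys.modules` (the SDK's actual criteria may be more involved):

```python
import os

def select_project_modules(module_files, project_root):
    """Pick modules whose source files live under the project directory.

    `module_files` maps module names to source paths, standing in for the
    information available from sys.modules.
    """
    root = os.path.abspath(project_root) + os.sep
    selected = {}
    for name, path in module_files.items():
        if path is None:                      # built-ins have no source file
            continue
        if not os.path.abspath(path).startswith(root):
            continue                          # skip site-packages and system locations
        if name == "flyte" or name.startswith("flyte."):
            continue                          # skip the Flyte SDK itself
        selected[name] = path
    return selected

bundle = select_project_modules(
    {
        "my_example": "/work/project/my_example.py",
        "helpers.io": "/work/project/helpers/io.py",
        "pandas": "/usr/lib/python3/site-packages/pandas/__init__.py",
        "flyte": "/usr/lib/python3/site-packages/flyte/__init__.py",
    },
    "/work/project",
)
```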
### `--copy_style all`
This comprehensive bundling strategy takes a directory-walking approach, recursively traversing your entire project directory and including every file it encounters.
Unlike the smart bundling that only includes imported Python modules, this method captures all project files regardless of whether they were imported during discovery.
This is particularly useful for projects that use dynamic imports, load configuration files or data assets at runtime, or have dependencies that aren't captured through normal Python import mechanisms.
### `--copy_style none`
This option completely skips code bundle creation, meaning no source code is packaged or uploaded to cloud storage.
When using this approach, you must provide an explicit version parameter since there's no code bundle to generate a version from.
This strategy is designed for scenarios where your code is already baked into custom container images, eliminating the need for separate code injection during task execution.
It results in the fastest deployment times but requires more complex image management workflows.
### `--root-dir` option
By default, Flyte uses your current working directory as the root for code bundling.
You can override this with `--root-dir` to specify a different base directory, which is particularly useful for monorepos or when deploying from subdirectories. The setting works with any copy style: `loaded_modules` looks for imported modules relative to the root directory, and `all` walks the directory tree starting from the root. See the [Deploy command options](./deploy-command-options#--root-dir) for detailed usage examples.
After the code bundle is created (if applicable), it is uploaded to a cloud storage location (like S3 or GCS) accessible by your Flyte backend. It is now ready to be run.
## 5. Image building
If your `TaskEnvironment` specifies [custom images](../task-configuration/container-images), Flyte builds and pushes container images before deploying tasks.
The build process varies based on your configuration and backend type:
### Local image building
When `image.builder` is set to `local` in [your `config.yaml`](../connecting-to-a-cluster), images are built on your local machine using Docker. This approach:
- Requires Docker to be installed and running on your development machine
- Uses Docker BuildKit to build images from generated Dockerfiles or your custom Dockerfile
- Pushes built images to the container registry specified in your `Image` configuration
- Is the only option available for Flyte OSS instances
### Remote image building
When `image.builder` is set to `remote` in [your `config.yaml`](../connecting-to-a-cluster), images are built on cloud infrastructure. This approach:
- Builds images using Union's ImageBuilder service (currently only available for Union backends, not OSS Flyte)
- Requires no local Docker installation or configuration
- Can push to Union's internal registry or external registries you specify
- Provides faster, more consistent builds by leveraging cloud resources
> [!NOTE]
> Remote building is currently exclusive to Union backends. OSS Flyte installations must use `local`.
## Understanding option relationships
It's important to understand how the various deployment options work together.
The **discovery options** (`--recursive` and `--all`) operate independently of the **bundling options** (`--copy-style`),
giving you flexibility in how you structure your deployments.
Environment discovery determines which files Flyte will examine to find `TaskEnvironment` objects,
while code bundling controls what gets packaged and uploaded for execution.
You can freely combine these approaches: for example, discover environments recursively across your entire project while using smart bundling to include only the necessary code modules.
When multiple environments are discovered, they all share the same code bundle, which is efficient for related services or components that use common dependencies:
```bash
flyte deploy --recursive --copy-style loaded_modules ./project
```
> [!NOTE]
> All discovered environments share the same code bundle.
For a full overview of all deployment options, see **Flyte CLI > flyte > flyte deploy**.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/deploy-command-options ===
# Deploy command options
The `flyte deploy` command provides extensive configuration options:
**`flyte deploy [OPTIONS] [TASK_ENV_VARIABLE]`**
| Option | Short | Type | Default | Description |
|-----------------------------|-------|--------|---------------------------|---------------------------------------------------|
| `--project` | `-p` | text | *from config* | Project to deploy to |
| `--domain` | `-d` | text | *from config* | Domain to deploy to |
| `--version` | | text | *auto-generated* | Explicit version tag for deployment |
| `--dry-run`/`--dryrun` | | flag | `false` | Preview deployment without executing |
| `--all` | | flag | `false` | Deploy all environments in specified path |
| `--recursive` | `-r` | flag | `false` | Deploy environments recursively in subdirectories |
| `--copy-style` | | choice | `loaded_modules` | Code bundling strategy (`loaded_modules`, `all`, or `none`) |
| `--root-dir` | | path | *current dir* | Override source root directory |
| `--image` | | text | | Image URI mappings (format: `name=uri`) |
| `--ignore-load-errors` | `-i` | flag | `false` | Continue deployment despite module load failures |
| `--no-sync-local-sys-paths` | | flag | `false` | Disable local `sys.path` synchronization |
## `--project`, `--domain`
**`flyte deploy --domain <domain> --project <project> <file> <env>`**
You can specify `--project` and `--domain` to override the defaults defined in your `config.yaml`. Without them, the configured defaults are used:
```bash
flyte deploy my_example.py env
```
Specify a target project and domain:
```bash
flyte deploy --project my-project --domain development my_example.py env
```
## `--version`
**`flyte deploy --version <version> <file> <env>`**
The `--version` option controls how deployed tasks are tagged and identified in the Flyte backend:
Auto-generated version (default):
```bash
flyte deploy my_example.py env
```
Explicit version:
```bash
flyte deploy --version v1.0.0 my_example.py env
```
> [!NOTE]
> An explicit version is required when using `--copy-style none`, since there is no code bundle to generate a hash from.
```bash
flyte deploy --copy-style none --version v1.0.0 my_example.py env
```
### When versions are used
- **Explicit versioning**: Provides human-readable task identification (e.g., `v1.0.0`, `prod-2024-12-01`)
- **Auto-generated versions**: When no version is specified, Flyte creates an MD5 hash from the code bundle, environment configuration, and image cache
- **Version requirement**: `--copy-style none` mandates explicit versions since there's no code bundle to hash
- **Task referencing**: Versions enable precise task references in `flyte run deployed-task` and workflow invocations
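The auto-generated version can be thought of as a deterministic digest of the deployment inputs: identical code, configuration, and images always produce the same tag. A hypothetical helper illustrating the idea (the exact inputs and format of Flyte's hash are internal to the SDK):

```python
import hashlib


def auto_version(bundle_bytes: bytes, env_config: str, image_id: str) -> str:
    """Illustrative only: derive a deterministic version tag by hashing
    the deployment inputs, mirroring the MD5-over-bundle/config/image-cache
    behavior described above."""
    digest = hashlib.md5()
    digest.update(bundle_bytes)       # the code bundle contents
    digest.update(env_config.encode())  # environment configuration
    digest.update(image_id.encode())    # image cache identity
    return digest.hexdigest()
```

Because the tag is derived purely from content, redeploying unchanged code yields the same version, while any change to code, configuration, or image produces a new one.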
## `--dry-run`
**`flyte deploy --dry-run <file> <env>`**
The `--dry-run` option allows you to preview what would be deployed without actually performing the deployment:
```bash
flyte deploy --dry-run my_example.py env
```
## `--all` and `--recursive`
**`flyte deploy --all <file>`**
**`flyte deploy --recursive <directory>`**
Control which environments get discovered and deployed:
**Single environment (default):**
```bash
flyte deploy my_example.py env
```
**All environments in file:**
```bash
flyte deploy --all my_example.py
```
**Recursive directory deployment:**
```bash
flyte deploy --recursive ./src
```
Combine with comprehensive bundling:
```bash
flyte deploy --recursive --copy-style all ./project
```
## `--copy-style`
**`flyte deploy --copy-style [loaded_modules|all|none] <file> <env>`**
The `--copy-style` option controls what gets packaged:
### `--copy-style loaded_modules` (default)
```bash
flyte deploy --copy-style loaded_modules my_example.py env
```
- **Includes**: Only imported Python modules from your project
- **Excludes**: Site-packages, system modules, Flyte SDK
- **Best for**: Most projects (optimal size and speed)
### `--copy-style all`
```bash
flyte deploy --copy-style all my_example.py env
```
- **Includes**: All files in project directory
- **Best for**: Projects with dynamic imports or data files
### `--copy-style none`
```bash
flyte deploy --copy-style none --version v1.0.0 my_example.py env
```
- **Requires**: Explicit version parameter
- **Best for**: Pre-built container images with baked-in code
## `--root-dir`
**`flyte deploy --root-dir <path> <file> <env>`**
The `--root-dir` option overrides the default source directory that Flyte uses as the base for code bundling and import resolution.
This is particularly useful for monorepos and projects with complex directory structures.
### Default behavior (without `--root-dir`)
- Flyte uses the current working directory as the root
- Code bundling starts from this directory
- Import paths are resolved relative to this location
### Common use cases
**Monorepos:**
Deploy a service from the monorepo root:
```bash
flyte deploy --root-dir ./services/ml ./services/ml/my_example.py env
```
Deploy from anywhere in the monorepo:
```bash
cd ./docs/
flyte deploy --root-dir ../services/ml ../services/ml/my_example.py env
```
**Cross-directory imports:**
When a workflow imports modules from sibling directories (e.g., `project/workflows/my_example.py` imports `project/src/utils.py`):
```bash
cd project/workflows/
flyte deploy --root-dir .. my_example.py env
```
**Working directory independence:**
```bash
flyte deploy --root-dir /path/to/project /path/to/project/my_example.py env
```
### How it works
1. **Code bundling**: Files are collected starting from `--root-dir` instead of the current working directory
2. **Import resolution**: Python imports are resolved relative to the specified root directory
3. **Path consistency**: Ensures the same directory structure in local and remote execution environments
4. **Dependency packaging**: Captures all necessary modules that may be located outside the workflow file's immediate directory
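The path handling above amounts to resolving every bundled file relative to the root directory, so the same relative layout exists locally and remotely. A minimal sketch (the function name is illustrative, not the SDK's API):

```python
from pathlib import Path


def arcname(file_path: str, root_dir: str) -> str:
    """Compute the path a file gets inside the code bundle: its location
    relative to --root-dir (illustrative sketch of the behavior above)."""
    return (
        Path(file_path).resolve()
        .relative_to(Path(root_dir).resolve())
        .as_posix()
    )
```

A file outside the root directory raises `ValueError` here, which mirrors why imports from outside `--root-dir` cannot be bundled.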
### Example with complex project structure
```
my-project/
├── services/
│   ├── ml/
│   │   └── my_example.py   # imports shared.utils
│   └── api/
└── shared/
    └── utils.py
```
```bash
flyte deploy --root-dir ./my-project ./my-project/services/ml/my_example.py env
```
This ensures that both `services/ml/` and `shared/` directories are included in the code bundle, allowing the workflow to successfully import `shared.utils` during remote execution.
## `--image`
**`flyte deploy --image <name>=<uri> <file> <env>`**
The `--image` option allows you to override image URIs at deployment time without modifying your code. Format: `imagename=imageuri`
### Named image mappings
```bash
flyte deploy --image base=ghcr.io/org/base:v1.0 my_example.py env
```
Multiple named image mappings:
```bash
flyte deploy \
--image base=ghcr.io/org/base:v1.0 \
--image gpu=ghcr.io/org/gpu:v2.0 \
my_example.py env
```
### Default image mapping
```bash
flyte deploy --image ghcr.io/org/default:latest my_example.py env
```
### How it works
- Named mappings (e.g., `base=URI`) override images created with `Image.from_ref_name("base")`.
- Unnamed mappings (e.g., just `URI`) override the default "auto" image.
- Multiple `--image` flags can be specified.
- Mappings are resolved during the image building phase of deployment.
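The resolution described above can be sketched as a small parser that folds repeated `--image` values into a name-to-URI mapping. This is illustrative only, not the CLI's actual code; treating the unnamed default under the key `"auto"` follows the description above:

```python
def parse_image_flags(flags: list[str]) -> dict[str, str]:
    """Sketch: fold repeated --image values into a name -> URI mapping.
    A value without '=' overrides the default "auto" image."""
    mappings: dict[str, str] = {}
    for value in flags:
        if "=" in value:
            name, uri = value.split("=", 1)  # split once; URIs may contain '='
            mappings[name] = uri
        else:
            mappings["auto"] = value
    return mappings
```

Each mapping then substitutes for the corresponding `Image.from_ref_name(name)` placeholder when images are resolved during deployment.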
## `--ignore-load-errors`
**`flyte deploy --ignore-load-errors <path>`**
The `--ignore-load-errors` option allows the deployment process to continue even if some modules fail to load during the environment discovery phase. This is particularly useful for large projects or monorepos where certain modules may have missing dependencies or other issues that prevent them from being imported successfully.
```bash
flyte deploy --recursive --ignore-load-errors ./large-project
```
## `--no-sync-local-sys-paths`
**`flyte deploy --no-sync-local-sys-paths <file> <env>`**
The `--no-sync-local-sys-paths` option disables the automatic synchronization of local `sys.path` entries to the remote container environment. This is an advanced option for specific deployment scenarios.
### Default behavior (path synchronization enabled)
- Flyte captures local `sys.path` entries that are under the root directory
- These paths are passed to the remote container via the `_F_SYS_PATH` environment variable
- At runtime, the remote container adds these paths to its `sys.path`, maintaining the same import environment
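Both sides of this synchronization can be sketched as follows. This is a simplified illustration, not the SDK's implementation; in particular, the exact encoding of `_F_SYS_PATH` (shown here as `os.pathsep`-separated) is an assumption:

```python
import os
import sys
from pathlib import Path


def paths_to_sync(root_dir: str) -> list[str]:
    """Local side (sketch): collect sys.path entries under the root
    directory, as described above."""
    root = Path(root_dir).resolve()
    return [
        p for p in sys.path
        if p and Path(p).resolve().is_relative_to(root)
    ]


def apply_synced_paths(env: dict[str, str]) -> None:
    """Remote side (sketch): extend sys.path from the _F_SYS_PATH
    environment variable so imports resolve as they did locally."""
    for p in env.get("_F_SYS_PATH", "").split(os.pathsep):
        if p and p not in sys.path:
            sys.path.append(p)
```

Paths outside the root directory are never synchronized, which is why local virtualenv or system paths do not leak into the container.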
### When to disable path synchronization
```bash
flyte deploy --no-sync-local-sys-paths my_example.py env
```
### Use cases for disabling
- **Custom container images**: When your container already has the correct `sys.path` configuration
- **Conflicting path structures**: When local development paths would interfere with container paths
- **Security concerns**: When you don't want to expose local development directory structures
- **Minimal environments**: When you want precise control over what gets added to the container's Python path
### How it works
- **Enabled (default)**: Local paths like `./my_project/utils` get synchronized and added to remote `sys.path`
- **Disabled**: Only the container's native `sys.path` is used, along with the deployed code bundle
Most users should leave path synchronization enabled unless they have specific requirements for container path isolation or are using pre-configured container environments.
## SDK deployment options
The core deployment functionality is available programmatically through the `flyte.deploy()` function, though some CLI-specific options are not applicable:
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def process_data(data: str) -> str:
return f"Processed: {data}"
if __name__ == "__main__":
flyte.init_from_config()
# Comprehensive deployment configuration
deployment = flyte.deploy(
env, # Environment to deploy
dryrun=False, # Set to True for dry run
version="v1.2.0", # Explicit version tag
copy_style="loaded_modules" # Code bundling strategy
)
print(f"Deployment successful: {deployment[0].summary_repr()}")
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/packaging ===
# Code packaging for remote execution
When you run Flyte tasks remotely, your code needs to be available in the execution environment. Flyte SDK provides two main approaches for packaging your code:
1. **Code bundling** - Bundle code dynamically at runtime
2. **Container-based deployment** - Embed code directly in container images
## Quick comparison
| Aspect | Code bundling | Container-based |
|--------|---------------|-----------------|
| **Speed** | Fast (no image rebuild) | Slower (requires image build) |
| **Best for** | Rapid development, iteration | Production, immutable deployments |
| **Code changes** | Immediate effect | Requires image rebuild |
| **Setup** | Automatic by default | Manual configuration needed |
| **Reproducibility** | Excellent (hash-based versioning) | Excellent (immutable images) |
| **Rollback** | Requires version control | Tag-based, straightforward |
---
## Code bundling
**Default approach** - Automatically bundles and uploads your code to remote storage at runtime.
### How it works
When you run `flyte run` or call `flyte.run()`, Flyte automatically:
1. **Scans loaded modules** from your codebase
2. **Creates a tarball** (gzipped, without timestamps for consistent hashing)
3. **Uploads to blob storage** (S3, GCS, Azure Blob)
4. **Deduplicates** based on content hashes
5. **Downloads in containers** at runtime
This process happens transparently: every container downloads and extracts the code bundle before execution.
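The timestamp-free bundling idea can be sketched with the standard library: fixing the modification times in both the tar entries and the gzip header makes the output byte-identical for identical inputs, so the content hash is stable. This is illustrative only; the SDK's actual bundle format and hash inputs are internal details:

```python
import gzip
import hashlib
import io
import tarfile


def bundle(files: dict[str, bytes]) -> bytes:
    """Sketch: build a gzipped tarball with all timestamps zeroed so that
    identical inputs always produce identical bytes."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name in sorted(files):            # stable file ordering
            info = tarfile.TarInfo(name=name)
            info.size = len(files[name])
            info.mtime = 0                    # strip tar-entry timestamps
            tar.addfile(info, io.BytesIO(files[name]))
    return gzip.compress(buf.getvalue(), mtime=0)  # strip gzip timestamp


def bundle_hash(files: dict[str, bytes]) -> str:
    """Content hash of the bundle: the basis for deduplication."""
    return hashlib.md5(bundle(files)).hexdigest()
```

Because the hash depends only on file names and contents, re-running unchanged code reuses the previously uploaded bundle instead of uploading a new one.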
> [!NOTE]
> Code bundling is optimized for speed:
> - Bundles are created without timestamps for consistent hashing
> - Identical code produces identical hashes, enabling deduplication
> - Only modified code triggers new uploads
> - Containers cache downloaded bundles
>
> **Reproducibility:** Flyte automatically versions code bundles based on content hash. The same code always produces the same hash, guaranteeing reproducibility without manual versioning. However, version control is still recommended for rollback capabilities.
### Automatic code bundling
**Default behavior** - Bundles all loaded modules automatically.
#### What gets bundled
Flyte includes modules that are:
- ✅ **Loaded when environment is parsed** (imported at module level)
- ✅ **Part of your codebase** (not system packages)
- ✅ **Within your project directory**
- ❌ **NOT lazily loaded** (imported inside functions)
- ❌ **NOT system-installed packages** (e.g., from site-packages)
#### Example: Basic automatic bundling
```python
# app.py
import flyte
from my_module import helper  # ✅ Bundled automatically
env = flyte.TaskEnvironment(
name="default",
image=flyte.Image.from_debian_base().with_pip_packages("pandas", "numpy")
)
@env.task
def process_data(x: int) -> int:
# This import won't be bundled (lazy load)
    from another_module import util  # ❌ Not bundled automatically
return helper.transform(x)
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(process_data, x=42)
print(run.url)
```
When you run this:
```bash
flyte run app.py process_data --x 42
```
Flyte automatically:
1. Bundles `app.py` and `my_module.py`
2. Preserves the directory structure
3. Uploads to blob storage
4. Makes it available in the remote container
#### Project structure example
```
my_project/
├── app.py                  # Main entry point
├── tasks/
│   ├── __init__.py
│   ├── data_tasks.py       # Flyte tasks
│   └── ml_tasks.py
└── utils/
    ├── __init__.py
    ├── preprocessing.py    # Business logic
    └── models.py
```
```python
# app.py
import flyte
from tasks.data_tasks import load_data  # ✅ Bundled
from tasks.ml_tasks import train_model  # ✅ Bundled
# utils modules imported in tasks are also bundled
@flyte.task
def pipeline(dataset: str) -> float:
data = load_data(dataset)
accuracy = train_model(data)
return accuracy
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(pipeline, dataset="train.csv")
```
**All modules are bundled with their directory structure preserved.**
### Manual code bundling
Control exactly what gets bundled by configuring the copy style.
#### Copy styles
Three options available:
1. **`"auto"`** (default) - Bundle loaded modules only
2. **`"all"`** - Bundle everything in the working directory
3. **`"none"`** - Skip bundling entirely (requires code in container)
#### Using `copy_style="all"`
Bundle all files under your project directory:
```python
import flyte
flyte.init_from_config()
# Bundle everything in current directory
run = flyte.with_runcontext(copy_style="all").run(
my_task,
input_data="sample.csv"
)
```
Or via CLI:
```bash
flyte run --copy-style=all app.py my_task --input-data sample.csv
```
**Use when:**
- You have data files or configuration that tasks need
- You use dynamic imports or lazy loading
- You want to ensure all project files are available
#### Using `copy_style="none"`
Skip code bundling (see **Run and deploy tasks > Code packaging for remote execution > Container-based deployment**):
```python
run = flyte.with_runcontext(copy_style="none").run(my_task, x=10)
```
### Controlling the root directory
The `root_dir` parameter controls which directory serves as the bundling root.
#### Why root directory matters
1. **Determines what gets bundled** - All code paths are relative to root_dir
2. **Preserves import structure** - Python imports must match the bundle structure
3. **Affects path resolution** - Files and modules are located relative to root_dir
#### Setting root directory
##### Via CLI
```bash
flyte run --root-dir /path/to/project app.py my_task
```
##### Programmatically
```python
import pathlib
import flyte
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent
)
```
#### Root directory use cases
##### Use case 1: Multi-module project
```
project/
├── src/
│   ├── workflows/
│   │   └── pipeline.py
│   └── utils/
│       └── helpers.py
└── config.yaml
```
```python
# src/workflows/pipeline.py
import pathlib
import flyte
from utils.helpers import process # Relative import from project root
# Set root to project root (not src/)
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent.parent.parent
)
@flyte.task
def my_task():
return process()
```
**Root set to `project/` so imports like `from utils.helpers` work correctly.**
##### Use case 2: Shared utilities
```
workspace/
├── shared/
│   └── common.py
└── project/
    └── app.py
```
```python
# project/app.py
import flyte
import pathlib
from shared.common import shared_function # Import from parent directory
# Set root to workspace/ to include shared/
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent.parent
)
```
##### Use case 3: Monorepo
```
monorepo/
├── libs/
│   ├── data/
│   └── models/
└── services/
    └── ml_service/
        └── workflows.py
```
```python
# services/ml_service/workflows.py
import flyte
import pathlib
from libs.data import loader # Import from monorepo root
from libs.models import predictor
# Set root to monorepo/ to include libs/
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent.parent.parent
)
```
#### Root directory best practices
1. **Set root_dir at project initialization** before importing any task modules
2. **Use absolute paths** with `pathlib.Path(__file__).parent` navigation
3. **Match your import structure** - if imports are relative to project root, set root_dir to project root
4. **Keep consistent** - use the same root_dir for both `flyte run` and `flyte.init()`
### Code bundling examples
#### Example: Standard Python package
```
my_package/
├── pyproject.toml
└── src/
    └── my_package/
        ├── __init__.py
        ├── main.py
        ├── data/
        │   ├── loader.py
        │   └── processor.py
        └── models/
            └── analyzer.py
```
```python
# src/my_package/main.py
import flyte
import pathlib
from my_package.data.loader import fetch_data
from my_package.data.processor import clean_data
from my_package.models.analyzer import analyze
env = flyte.TaskEnvironment(
name="pipeline",
image=flyte.Image.from_debian_base().with_uv_project(
pyproject_file=pathlib.Path(__file__).parent.parent.parent / "pyproject.toml"
)
)
@env.task
async def fetch_task(url: str) -> dict:
return await fetch_data(url)
@env.task
def process_task(raw_data: dict) -> list[dict]:
return clean_data(raw_data)
@env.task
def analyze_task(data: list[dict]) -> str:
return analyze(data)
if __name__ == "__main__":
import flyte.git
# Set root to project root for proper imports
flyte.init_from_config(
flyte.git.config_from_root(),
root_dir=pathlib.Path(__file__).parent.parent.parent
)
# All modules bundled automatically
run = flyte.run(analyze_task, data=[{"value": 1}, {"value": 2}])
print(f"Run URL: {run.url}")
```
**Run with:**
```bash
cd my_package
flyte run src/my_package/main.py analyze_task --data '[{"value": 1}]'
```
#### Example: Dynamic environment based on domain
```python
# environment_picker.py
import flyte
def create_env():
"""Create different environments based on domain."""
if flyte.current_domain() == "development":
return flyte.TaskEnvironment(
name="dev",
image=flyte.Image.from_debian_base(),
env_vars={"ENV": "dev", "DEBUG": "true"}
)
elif flyte.current_domain() == "staging":
return flyte.TaskEnvironment(
name="staging",
image=flyte.Image.from_debian_base(),
env_vars={"ENV": "staging", "DEBUG": "false"}
)
else: # production
return flyte.TaskEnvironment(
name="prod",
image=flyte.Image.from_debian_base(),
env_vars={"ENV": "production", "DEBUG": "false"},
resources=flyte.Resources(cpu="2", memory="4Gi")
)
env = create_env()
@env.task
async def process(n: int) -> int:
import os
print(f"Running in {os.getenv('ENV')} environment")
return n * 2
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(process, n=5)
print(run.url)
```
**Why this works:**
- `flyte.current_domain()` is set correctly when Flyte re-instantiates modules remotely
- Environment configuration is deterministic and reproducible
- Code automatically bundled with domain-specific settings
> [!NOTE]
> `flyte.current_domain()` only works after `flyte.init()` is called:
> - ✅ Works with `flyte run` and `flyte deploy` (auto-initialize)
> - ✅ Works in `if __name__ == "__main__"` after explicit `flyte.init()`
> - ❌ Does NOT work at module level without initialization
### When to use code bundling
✅ **Use code bundling when:**
- Rapid development and iteration
- Frequently changing code
- Multiple developers testing changes
- Jupyter notebook workflows
- Quick prototyping and experimentation
❌ **Consider container-based instead when:**
- Need easy rollback to previous versions (container tags are simpler than finding git commits)
- Working with air-gapped environments (no blob storage access)
- Code changes require coordinated dependency updates
---
## Container-based deployment
**Advanced approach** - Embed code directly in container images for immutable deployments.
### How it works
Instead of bundling code at runtime:
1. **Build container image** with code copied inside
2. **Disable code bundling** with `copy_style="none"`
3. **Container has everything** needed at runtime
**Trade-off:** Every code change requires a new image build (slower), but provides complete reproducibility.
### Configuration
Three key steps:
#### 1. Set `copy_style="none"`
Disable runtime code bundling:
```python
flyte.with_runcontext(copy_style="none").run(my_task, n=10)
```
Or via CLI:
```bash
flyte run --copy-style=none app.py my_task --n 10
```
#### 2. Copy code into the image
Use `Image.with_source_file()` or `Image.with_source_folder()`:
```python
import pathlib
import flyte
env = flyte.TaskEnvironment(
name="embedded",
image=flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
)
```
#### 3. Set the correct `root_dir`
Match your image copy configuration:
```python
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent
)
```
### Image source copying methods
#### `with_source_file()` - Copy individual files
Copy a single file into the container:
```python
image = flyte.Image.from_debian_base().with_source_file(
src=pathlib.Path(__file__),
dst="/app/main.py"
)
```
**Use for:**
- Single-file workflows
- Copying configuration files
- Adding scripts to existing images
#### `with_source_folder()` - Copy directories
Copy entire directories into the container:
```python
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
dst="/app",
copy_contents_only=False # Copy folder itself
)
```
**Parameters:**
- `src`: Source directory path
- `dst`: Destination path in container (optional, defaults to workdir)
- `copy_contents_only`: If `True`, copies folder contents; if `False`, copies folder itself
##### `copy_contents_only=True` (Recommended)
Copies only the contents of the source folder:
```python
# Project structure:
# my_project/
# ├── app.py
# └── utils.py
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
# Container will have:
# /app/app.py
# /app/utils.py
# Set root_dir to match:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
```
##### `copy_contents_only=False`
Copies the folder itself with its name:
```python
# Project structure:
# workspace/
# └── my_project/
#     ├── app.py
#     └── utils.py
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent, # Points to my_project/
copy_contents_only=False
)
# Container will have:
# /app/my_project/app.py
# /app/my_project/utils.py
# Set root_dir to parent to match:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
### Complete container-based example
```python
# full_build.py
import pathlib
import flyte
from dep import helper # Local module
# Configure environment with source copying
env = flyte.TaskEnvironment(
name="full_build",
image=flyte.Image.from_debian_base()
.with_pip_packages("numpy", "pandas")
.with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
)
@env.task
def square(x: int) -> int:
return x ** helper.get_exponent()
@env.task
def main(n: int) -> list[int]:
return list(flyte.map(square, range(n)))
if __name__ == "__main__":
import flyte.git
# Initialize with matching root_dir
flyte.init_from_config(
flyte.git.config_from_root(),
root_dir=pathlib.Path(__file__).parent
)
# Run with copy_style="none" and explicit version
run = flyte.with_runcontext(
copy_style="none",
version="v1.0.0" # Explicit version for image tagging
).run(main, n=10)
print(f"Run URL: {run.url}")
run.wait()
```
**Project structure:**
```
project/
├── full_build.py
├── dep.py              # Local dependency
└── .flyte/
    └── config.yaml
```
**Run with:**
```bash
python full_build.py
```
This will:
1. Build a container image with `full_build.py` and `dep.py` embedded
2. Tag it as `v1.0.0`
3. Push to registry
4. Execute remotely without code bundling
### Using externally built images
When containers are built outside of Flyte (e.g., in CI/CD), use `Image.from_ref_name()`:
#### Step 1: Build your image externally
```dockerfile
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Copy your code
COPY src/ /app/
# Install dependencies
RUN pip install flyte pandas numpy
# Ensure flyte executable is available
RUN flyte --help
```
Build and push the image:
```bash
docker build -t myregistry.com/my-app:v1.2.3 .
docker push myregistry.com/my-app:v1.2.3
```
#### Step 2: Reference image by name
```python
# app.py
import flyte
env = flyte.TaskEnvironment(
name="external",
image=flyte.Image.from_ref_name("my-app-image") # Reference name
)
@env.task
def process(x: int) -> int:
return x * 2
if __name__ == "__main__":
flyte.init_from_config()
# Pass actual image URI at deploy/run time
run = flyte.with_runcontext(
copy_style="none",
images={"my-app-image": "myregistry.com/my-app:v1.2.3"}
).run(process, x=10)
```
Or via CLI:
```bash
flyte run \
--copy-style=none \
--image my-app-image=myregistry.com/my-app:v1.2.3 \
app.py process --x 10
```
**For deployment:**
```bash
flyte deploy \
--image my-app-image=myregistry.com/my-app:v1.2.3 \
app.py
```
#### Why use reference names?
1. **Decouples code from image URIs** - Change images without modifying code
2. **Supports multiple environments** - Different images for dev/staging/prod
3. **Integrates with CI/CD** - Build images in pipelines, reference in code
4. **Enables image reuse** - Multiple tasks can reference the same image
#### Example: Multi-environment deployment
```python
import flyte
import os
# Code references image by name
env = flyte.TaskEnvironment(
name="api",
image=flyte.Image.from_ref_name("api-service")
)
@env.task
def api_call(endpoint: str) -> dict:
# Implementation
return {"status": "success"}
if __name__ == "__main__":
flyte.init_from_config()
# Determine image based on environment
environment = os.getenv("ENV", "dev")
image_uri = {
"dev": "myregistry.com/api-service:dev",
"staging": "myregistry.com/api-service:staging",
"prod": "myregistry.com/api-service:v1.2.3"
}[environment]
run = flyte.with_runcontext(
copy_style="none",
images={"api-service": image_uri}
).run(api_call, endpoint="/health")
```
### Container-based best practices
1. **Always set explicit versions** when using `copy_style="none"`:
```python
flyte.with_runcontext(copy_style="none", version="v1.0.0")
```
2. **Match `root_dir` to `copy_contents_only`**:
   - `copy_contents_only=True` → `root_dir=Path(__file__).parent`
   - `copy_contents_only=False` → `root_dir=Path(__file__).parent.parent`
3. **Ensure `flyte` executable is in container** - Add to PATH or install flyte package
4. **Use `.dockerignore`** to exclude unnecessary files:
```
# .dockerignore
__pycache__/
*.pyc
.git/
.venv/
*.egg-info/
```
5. **Test containers locally** before deploying:
```bash
docker run -it myimage:latest /bin/bash
# then, inside the container:
python -c "import mymodule"  # Verify imports work
```
### When to use container-based deployment
✅ **Use container-based when:**
- Deploying to production
- Need immutable, reproducible environments
- Working with complex system dependencies
- Deploying to air-gapped or restricted environments
- CI/CD pipelines with automated builds
- Code changes are infrequent
❌ **Don't use container-based when:**
- Rapid development and frequent code changes
- Quick prototyping
- Interactive development (Jupyter notebooks)
- Learning and experimentation
---
## Choosing the right approach
### Decision tree
```
Are you iterating quickly on code?
├─ Yes → Use Code Bundling (Default)
│        (Development, prototyping, notebooks)
│        Both approaches are fully reproducible via hash/tag
└─ No → Do you need easy version rollback?
   ├─ Yes → Use Container-based
   │        (Production, CI/CD, straightforward tag-based rollback)
   └─ No → Either works
            (Code bundling is simpler, container-based for air-gapped)
```
### Hybrid approach
You can use different approaches for different tasks:
```python
import flyte
import pathlib
# Fast iteration for development tasks
dev_env = flyte.TaskEnvironment(
name="dev",
image=flyte.Image.from_debian_base().with_pip_packages("pandas")
# Code bundling (default)
)
# Immutable containers for production tasks
prod_env = flyte.TaskEnvironment(
name="prod",
image=flyte.Image.from_debian_base()
.with_pip_packages("pandas")
.with_source_folder(pathlib.Path(__file__).parent, copy_contents_only=True)
# Requires copy_style="none"
)
@dev_env.task
def experimental_task(x: int) -> int:
# Rapid development with code bundling
return x * 2
@prod_env.task
def stable_task(x: int) -> int:
# Production with embedded code
return x ** 2
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
# Use code bundling for dev task
dev_run = flyte.run(experimental_task, x=5)
# Use container-based for prod task
prod_run = flyte.with_runcontext(
copy_style="none",
version="v1.0.0"
).run(stable_task, x=5)
```
---
## Troubleshooting
### Import errors
**Problem:** `ModuleNotFoundError` when task executes remotely
**Solutions:**
1. **Check loaded modules** - Ensure modules are imported at module level:
```python
# β Good - bundled automatically
from mymodule import helper
@flyte.task
def my_task():
return helper.process()
```
```python
# β Bad - not bundled (lazy load)
@flyte.task
def my_task():
from mymodule import helper
return helper.process()
```
2. **Verify `root_dir`** matches your import structure:
```python
# If imports are: from mypackage.utils import foo
# Then root_dir should be parent of mypackage/
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
3. **Use `copy_style="all"`** to bundle everything:
```bash
flyte run --copy-style=all app.py my_task
```
### Code changes not reflected
**Problem:** Remote execution uses old code despite local changes
> [!NOTE]
> This is rare with code bundling - Flyte automatically versions based on content hash, so code changes should be detected automatically. This issue typically occurs with caching problems or when using `copy_style="none"`.
**Solutions:**
1. **Use explicit version bump** (mainly for container-based deployments):
```python
run = flyte.with_runcontext(version="v2").run(my_task)
```
2. **Check if `copy_style="none"`** is set - this requires image rebuild:
```python
# If using copy_style="none", rebuild image
run = flyte.with_runcontext(
copy_style="none",
version="v2" # Bump version to force rebuild
).run(my_task)
```
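For intuition, the content-hash versioning described in the note above can be sketched as a hash over bundled files. This is illustrative only, not Flyte's actual implementation:

```python
import hashlib
import pathlib

def bundle_version(files: list[pathlib.Path]) -> str:
    # Derive a version string from file contents: any code change
    # produces a different hash, so remote runs can't silently reuse
    # a stale bundle.
    digest = hashlib.sha256()
    for path in sorted(files):
        digest.update(path.name.encode())
        digest.update(path.read_bytes())
    return digest.hexdigest()[:12]
```

Because the version is derived from content, editing any bundled file yields a new version automatically; explicit version bumps only matter when the code path bypasses bundling (`copy_style="none"`).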
### Files missing in container
**Problem:** Task can't find data files or configs
**Solutions:**
1. **Use `copy_style="all"`** to bundle all files:
```bash
flyte run --copy-style=all app.py my_task
```
2. **Copy files explicitly in image**:
```python
image = flyte.Image.from_debian_base().with_source_file(
src=pathlib.Path("config.yaml"),
dst="/app/config.yaml"
)
```
3. **Store data in remote storage** instead of bundling:
```python
@flyte.task
def my_task():
# Read from S3/GCS instead of local files
import flyte.io
data = flyte.io.File("s3://bucket/data.csv").open().read()
```
### Container build failures
**Problem:** Image build fails with `copy_style="none"`
**Solutions:**
1. **Check `root_dir` matches `copy_contents_only`**:
```python
# copy_contents_only=True
image = Image.from_debian_base().with_source_folder(
src=Path(__file__).parent,
copy_contents_only=True
)
flyte.init(root_dir=Path(__file__).parent) # Match!
```
2. **Ensure `flyte` executable available**:
```python
image = Image.from_debian_base() # Has flyte pre-installed
```
3. **Check file permissions** in source directory:
```bash
chmod -R +r project/
```
### Version conflicts
**Problem:** Multiple versions of same image causing confusion
**Solutions:**
1. **Use explicit versions**:
```python
run = flyte.with_runcontext(
copy_style="none",
version="v1.2.3" # Explicit, not auto-generated
).run(my_task)
```
2. **Clean old images**:
```bash
docker image prune -a
```
3. **Use semantic versioning** for clarity:
```python
version = "v1.0.0" # Major.Minor.Patch
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/deployment-patterns ===
# Deployment patterns
Once you understand the basics of task deployment, you can leverage various deployment patterns to handle different project structures, dependency management approaches, and deployment requirements. This section covers the most common patterns with practical examples.
## Overview of deployment patterns
Flyte supports multiple deployment patterns to accommodate different project structures and requirements:
1. **Simple file deployment** - Single file with tasks and environments
2. **Custom Dockerfile deployment** - Full control over container environment
3. **PyProject package deployment** - Structured Python packages with dependencies and async tasks
4. **Package structure deployment** - Organized packages with shared environments
5. **Full build deployment** - Complete code embedding in containers
6. **Python path deployment** - Multi-directory project structures
7. **Dynamic environment deployment** - Environment selection based on domain context
Each pattern serves specific use cases and can be combined as needed for complex projects.
## Simple file deployment
The simplest deployment pattern involves defining both your tasks and task environment in a single Python file. This pattern works well for:
- Prototyping and experimentation
- Simple tasks with minimal dependencies
- Educational examples and tutorials
### Example structure
```python
import flyte
env = flyte.TaskEnvironment(name="simple_env")
@env.task
async def my_task(name: str) -> str:
return f"Hello, {name}!"
if __name__ == "__main__":
flyte.init_from_config()
flyte.deploy(env)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/simple_file.py*
### Deployment commands
Deploy the environment:
```bash
flyte deploy my_example.py env
```
Run the task ephemerally:
```bash
flyte run my_example.py my_task --name "World"
```
### When to use
- Quick prototypes and experiments
- Single-purpose scripts
- Learning Flyte basics
- Tasks with no external dependencies
## Custom Dockerfile deployment
When you need full control over the container environment, you can specify a custom Dockerfile. This pattern is ideal for:
- Complex system dependencies
- Specific OS or runtime requirements
- Custom base images
- Multi-stage builds
### Example structure
```dockerfile
# syntax=docker/dockerfile:1.5
FROM ghcr.io/astral-sh/uv:0.8 as uv
FROM python:3.12-slim-bookworm
USER root
# Copy in uv so that later commands don't have to mount it in
COPY --from=uv /uv /usr/bin/uv
# Configure default envs
ENV UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
VIRTUALENV=/opt/venv \
UV_PYTHON=/opt/venv/bin/python \
PATH="/opt/venv/bin:$PATH"
# Create a virtualenv with the user specified python version
RUN uv venv /opt/venv --python=3.12
WORKDIR /root
# Install dependencies
COPY requirements.txt .
RUN uv pip install --pre -r /root/requirements.txt
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/Dockerfile*
```python
from pathlib import Path
import flyte
env = flyte.TaskEnvironment(
name="docker_env",
image=flyte.Image.from_dockerfile(
# relative paths in python change based on where you call, so set it relative to this file
Path(__file__).parent / "Dockerfile",
registry="ghcr.io/flyteorg",
name="docker_env_image",
),
)
@env.task
def main(x: int) -> int:
return x * 2
if __name__ == "__main__":
import flyte.git
flyte.init_from_config(flyte.git.config_from_root())
run = flyte.run(main, x=10)
print(run.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/dockerfile_env.py*
### Alternative: Dockerfile in different directory
You can also reference Dockerfiles from subdirectories:
```python
from pathlib import Path
import flyte
env = flyte.TaskEnvironment(
name="docker_env_in_dir",
image=flyte.Image.from_dockerfile(
# relative paths in python change based on where you call, so set it relative to this file
Path(__file__).parent.parent / "Dockerfile.workdir",
registry="ghcr.io/flyteorg",
name="docker_env_image",
),
)
@env.task
def main(x: int) -> int:
return x * 2
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main, x=10)
print(run.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/src/docker_env_in_dir.py*
```dockerfile
# syntax=docker/dockerfile:1.5
FROM ghcr.io/astral-sh/uv:0.8 as uv
FROM python:3.12-slim-bookworm
USER root
# Copy in uv so that later commands don't have to mount it in
COPY --from=uv /uv /usr/bin/uv
# Configure default envs
ENV UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
VIRTUALENV=/opt/venv \
UV_PYTHON=/opt/venv/bin/python \
PATH="/opt/venv/bin:$PATH"
# Create a virtualenv with the user specified python version
RUN uv venv /opt/venv --python=3.12
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN uv pip install --pre -r /app/requirements.txt
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/Dockerfile.workdir*
### Key considerations
- **Path handling**: Use `Path(__file__).parent` for relative Dockerfile paths
```python
# relative paths in python change based on where you call, so set it relative to this file
Path(__file__).parent / "Dockerfile"
```
- **Registry configuration**: Specify a registry for image storage
- **Build context**: The directory containing the Dockerfile becomes the build context
- **Flyte installation**: Ensure Flyte is installed in the container and available on `$PATH`
```dockerfile
# Install Flyte in your Dockerfile
RUN pip install flyte
```
- **Dependencies**: Include all application requirements in the Dockerfile or requirements.txt
### When to use
- Need specific system packages or tools
- Custom base image requirements
- Complex installation procedures
- Multi-stage build optimization
## PyProject package deployment
For structured Python projects with proper package management, use the PyProject pattern. This approach demonstrates a **realistic Python project structure** that provides:
- Proper dependency management with `pyproject.toml` and external packages like `httpx`
- Clean separation of business logic and Flyte tasks across multiple modules
- Professional project structure with `src/` layout
- Async task execution with API calls and data processing
- Entrypoint patterns for both command-line and programmatic execution
### Example structure
```
pyproject_package/
├── pyproject.toml            # Project metadata and dependencies
├── README.md                 # Documentation
└── src/
    └── pyproject_package/
        ├── __init__.py       # Package initialization
        ├── main.py           # Entrypoint script
        ├── data/
        │   ├── __init__.py
        │   ├── loader.py     # Data loading utilities (no Flyte)
        │   └── processor.py  # Data processing utilities (no Flyte)
        ├── models/
        │   ├── __init__.py
        │   └── analyzer.py   # Analysis utilities (no Flyte)
        └── tasks/
            ├── __init__.py
            └── tasks.py      # Flyte task definitions
```
### Business logic modules
The business logic is completely separate from Flyte and can be used independently:
#### Data Loading (`data/loader.py`)
```python
import json
from pathlib import Path
from typing import Any
import httpx
async def fetch_data_from_api(url: str) -> list[dict[str, Any]]:
async with httpx.AsyncClient() as client:
response = await client.get(url, timeout=10.0)
response.raise_for_status()
return response.json()
def load_local_data(file_path: str | Path) -> dict[str, Any]:
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
with path.open("r") as f:
return json.load(f)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/data/loader.py*
#### Data Processing (`data/processor.py`)
```python
import asyncio
from typing import Any
from pydantic import BaseModel, Field, field_validator
class DataItem(BaseModel):
id: int = Field(gt=0, description="Item ID must be positive")
value: float = Field(description="Item value")
category: str = Field(min_length=1, description="Item category")
@field_validator("category")
@classmethod
def category_must_be_lowercase(cls, v: str) -> str:
return v.lower()
def clean_data(raw_data: dict[str, Any]) -> dict[str, Any]:
# Remove None values
cleaned = {k: v for k, v in raw_data.items() if v is not None}
# Validate items if present
if "items" in cleaned:
validated_items = []
for item in cleaned["items"]:
try:
validated = DataItem(**item)
validated_items.append(validated.model_dump())
except Exception as e:
print(f"Skipping invalid item {item}: {e}")
continue
cleaned["items"] = validated_items
return cleaned
def transform_data(data: dict[str, Any]) -> list[dict[str, Any]]:
items = data.get("items", [])
# Add computed fields
transformed = []
for item in items:
transformed_item = {
**item,
"value_squared": item["value"] ** 2,
"category_upper": item["category"].upper(),
}
transformed.append(transformed_item)
return transformed
async def aggregate_data(items: list[dict[str, Any]]) -> dict[str, Any]:
# Simulate async processing
await asyncio.sleep(0.1)
aggregated: dict[str, dict[str, Any]] = {}
for item in items:
category = item["category"]
if category not in aggregated:
aggregated[category] = {
"count": 0,
"total_value": 0.0,
"values": [],
}
aggregated[category]["count"] += 1
aggregated[category]["total_value"] += item["value"]
aggregated[category]["values"].append(item["value"])
# Calculate averages
for category, v in aggregated.items():
total = v["total_value"]
count = v["count"]
v["average_value"] = total / count if count > 0 else 0.0
return {"categories": aggregated, "total_items": len(items)}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/data/processor.py*
#### Analysis (`models/analyzer.py`)
```python
from typing import Any
import numpy as np
def calculate_statistics(data: list[dict[str, Any]]) -> dict[str, Any]:
if not data:
return {
"count": 0,
"mean": 0.0,
"median": 0.0,
"std_dev": 0.0,
"min": 0.0,
"max": 0.0,
}
values = np.array([item["value"] for item in data])
stats = {
"count": len(values),
"mean": float(np.mean(values)),
"median": float(np.median(values)),
"std_dev": float(np.std(values)),
"min": float(np.min(values)),
"max": float(np.max(values)),
"percentile_25": float(np.percentile(values, 25)),
"percentile_75": float(np.percentile(values, 75)),
}
return stats
def generate_report(stats: dict[str, Any]) -> str:
report_lines = [
"=" * 60,
"DATA ANALYSIS REPORT",
"=" * 60,
]
# Basic statistics section
if "basic" in stats:
basic = stats["basic"]
report_lines.extend(
[
"",
"BASIC STATISTICS:",
f" Count: {basic.get('count', 0)}",
f" Mean: {basic.get('mean', 0.0):.2f}",
f" Median: {basic.get('median', 0.0):.2f}",
f" Std Dev: {basic.get('std_dev', 0.0):.2f}",
f" Min: {basic.get('min', 0.0):.2f}",
f" Max: {basic.get('max', 0.0):.2f}",
f" 25th %ile: {basic.get('percentile_25', 0.0):.2f}",
f" 75th %ile: {basic.get('percentile_75', 0.0):.2f}",
]
)
# Category aggregations section
if "aggregated" in stats and "categories" in stats["aggregated"]:
categories = stats["aggregated"]["categories"]
total_items = stats["aggregated"].get("total_items", 0)
report_lines.extend(
[
"",
"CATEGORY BREAKDOWN:",
f" Total Items: {total_items}",
"",
]
)
for category, cat_stats in sorted(categories.items()):
report_lines.extend(
[
f" Category: {category.upper()}",
f" Count: {cat_stats.get('count', 0)}",
f" Total Value: {cat_stats.get('total_value', 0.0):.2f}",
f" Average Value: {cat_stats.get('average_value', 0.0):.2f}",
"",
]
)
report_lines.append("=" * 60)
return "\n".join(report_lines)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/models/analyzer.py*
These modules demonstrate:
- **No Flyte dependencies** - can be tested and used independently
- **Pydantic models** for data validation with custom validators
- **Async patterns** with proper context managers and error handling
- **NumPy integration** for statistical calculations
- **Professional error handling** with timeouts and validation
### Flyte orchestration layer
The Flyte tasks orchestrate the business logic with proper async execution:
```python
import pathlib
from typing import Any
import flyte
from pyproject_package.data import loader, processor
from pyproject_package.models import analyzer
UV_PROJECT_ROOT = pathlib.Path(__file__).parent.parent.parent.parent
env = flyte.TaskEnvironment(
name="data_pipeline",
image=flyte.Image.from_debian_base().with_uv_project(pyproject_file=UV_PROJECT_ROOT / "pyproject.toml"),
resources=flyte.Resources(memory="512Mi", cpu="500m"),
)
@env.task
async def fetch_task(url: str) -> list[dict[str, Any]]:
print(f"Fetching data from: {url}")
data = await loader.fetch_data_from_api(url)
print(f"Fetched {len(data)} top-level keys")
return data
@env.task
async def process_task(raw_data: dict[str, Any]) -> list[dict[str, Any]]:
print("Cleaning data...")
cleaned = processor.clean_data(raw_data)
print("Transforming data...")
transformed = processor.transform_data(cleaned)
print(f"Processed {len(transformed)} items")
return transformed
@env.task
async def analyze_task(processed_data: list[dict[str, Any]]) -> str:
print("Aggregating data...")
aggregated = await processor.aggregate_data(processed_data)
print("Calculating statistics...")
stats = analyzer.calculate_statistics(processed_data)
print("Generating report...")
report = analyzer.generate_report({"basic": stats, "aggregated": aggregated})
print("\n" + report)
return report
@env.task
async def pipeline(api_url: str) -> str:
# Chain tasks together
raw_data = await fetch_task(url=api_url)
processed_data = await process_task(raw_data=raw_data[0])
report = await analyze_task(processed_data=processed_data)
return report
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/tasks/tasks.py*
### Entrypoint configuration
The main entrypoint demonstrates proper initialization and execution patterns:
```python
import pathlib
import flyte
from pyproject_package.tasks.tasks import pipeline
def main():
# Initialize Flyte connection
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
    # Example API endpoint serving mock JSON data, so the example runs reliably
    # In a real scenario, this would be your own API endpoint
    example_url = "https://jsonplaceholder.typicode.com/posts"
print("Starting data pipeline...")
print(f"Target API: {example_url}")
    # Run the pipeline remotely and wait for the result
run = flyte.run(pipeline, api_url=example_url)
print(f"\nRun Name: {run.name}")
print(f"Run URL: {run.url}")
run.wait()
if __name__ == "__main__":
main()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/main.py*
### Dependencies and configuration
```toml
[project]
name = "pyproject-package"
version = "0.1.0"
description = "Example Python package with Flyte tasks and modular business logic"
readme = "README.md"
authors = [
{ name = "Ketan Umare", email = "kumare3@users.noreply.github.com" }
]
requires-python = ">=3.10"
dependencies = [
"flyte>=2.0.0b52",
"httpx>=0.27.0",
"numpy>=1.26.0",
"pydantic>=2.0.0",
]
[project.scripts]
run-pipeline = "pyproject_package.main:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/pyproject.toml*
### Key features
- **Async task chains**: Tasks can be chained together with proper async/await patterns
- **External dependencies**: Demonstrates integration with external libraries (`httpx`, `numpy`, `pydantic`)
- **uv integration**: Uses `.with_uv_project()` for dependency management
- **Resource specification**: Shows how to set memory and CPU requirements
- **Proper error handling**: Includes timeout and error handling in API calls
### Key learning points
1. **Separation of concerns**: Business logic (`data/`, `models/`) separate from orchestration (`main.py`)
2. **Reusable code**: Non-Flyte modules can be tested independently and reused
3. **Async support**: Demonstrates async Flyte tasks for I/O-bound operations
4. **Dependency management**: Shows how external packages integrate with Flyte
5. **Realistic structure**: Mirrors real-world Python project organization
6. **Entrypoint script**: Shows how to create runnable entry points
### Usage patterns
**Run the pipeline** (the entrypoint calls `flyte.run`, so tasks execute remotely):
```bash
python -m pyproject_package.main
```
**Deploy to Flyte:**
```bash
flyte deploy .
```
### What this example demonstrates
- Multiple files and modules in a package
- Async Flyte tasks with external API calls
- Separation of business logic from orchestration
- External dependencies (`httpx`, `numpy`, `pydantic`)
- **Data validation with Pydantic models** for robust data processing
- **Professional error handling** with try/catch for data validation
- **Timeout configuration** for external API calls (`timeout=10.0`)
- **Async context managers** for proper resource management (`async with httpx.AsyncClient()`)
- Entrypoint script pattern with `project.scripts`
- Realistic project structure with `src/` layout
- Task chaining and data flow
- How non-Flyte code integrates with Flyte tasks
### When to use
- Production-ready, maintainable projects
- Projects requiring external API integration
- Complex data processing pipelines
- Team development with proper separation of concerns
- Applications needing async execution patterns
## Package structure deployment
For organizing Flyte workflows in a package structure with shared task environments and utilities, use this pattern. It's particularly useful for:
- Multiple workflows that share common environments and utilities
- Organized code structure with clear module boundaries
- Projects where you want to reuse task environments across workflows
### Example structure
```
lib/
├── __init__.py
└── workflows/
    ├── __init__.py
    ├── workflow1.py  # First workflow
    ├── workflow2.py  # Second workflow
    ├── env.py        # Shared task environment
    └── utils.py      # Shared utilities
```
### Key concepts
- **Shared environments**: Define task environments in `env.py` and import across workflows
- **Utility modules**: Common functions and utilities shared between workflows
- **Root directory handling**: Use `--root-dir` flag for proper Python path configuration
### Running with root directory
When running workflows with a package structure, specify the root directory:
```bash
flyte run --root-dir . lib/workflows/workflow1.py process_workflow
flyte run --root-dir . lib/workflows/workflow2.py math_workflow --n 6
```
### How `--root-dir` works
The `--root-dir` flag automatically configures the Python path (`sys.path`) to ensure:
1. **Local execution**: Package imports work correctly when running locally
2. **Consistent behavior**: Same Python path configuration locally and at runtime
3. **No manual PYTHONPATH**: Eliminates need to manually export environment variables
4. **Runtime packaging**: Flyte packages and copies code correctly to execution environment
5. **Runtime consistency**: The same package structure is preserved in the runtime container
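Locally, the effect of the flag amounts to prepending the root directory to `sys.path`. A rough sketch (the helper is illustrative, not how the CLI is implemented):

```python
import pathlib
import sys

def add_root_to_path(root_dir: str) -> None:
    # Roughly what `--root-dir` does for local imports: put the project
    # root first on sys.path so imports like
    # `from lib.workflows import workflow1` resolve.
    root = str(pathlib.Path(root_dir).resolve())
    if root not in sys.path:
        sys.path.insert(0, root)
```

This is why no manual `PYTHONPATH` export is needed when using the CLI with `--root-dir`.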
### Alternative: Using a Python project
For larger projects, create a proper Python project with `pyproject.toml`:
```toml
# pyproject.toml
[project]
name = "lib"
version = "0.1.0"
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
```
Then install in editable mode:
```bash
pip install -e .
```
After installation, you can run workflows without `--root-dir`:
```bash
flyte run lib/workflows/workflow1.py process_workflow
```
However, for deployment and remote execution, still use `--root-dir` for consistency:
```bash
flyte run --root-dir . lib/workflows/workflow1.py process_workflow
flyte deploy --root-dir . lib/workflows/workflow1.py
```
### When to use
- Multiple related workflows in one project
- Shared task environments and utilities
- Team projects with multiple contributors
- Applications requiring organized code structure
- Projects that benefit from proper Python packaging
## Full build deployment
When you need complete reproducibility and want to embed all code directly in the container image, use the full build pattern. This disables Flyte's fast deployment system in favor of traditional container builds.
### Overview
By default, Flyte uses a fast deployment system that:
- Creates a tar archive of your files
- Skips the full image build and push process
- Provides faster iteration during development
However, sometimes you need to **completely embed your code into the container image** for:
- Full reproducibility with immutable container images
- Environments where fast deployment isn't available
- Production deployments with all dependencies baked in
- Air-gapped or restricted deployment environments
### Key configuration
```python
import pathlib
from dep import foo
import flyte
env = flyte.TaskEnvironment(
name="full_build",
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=True # Avoid nested folders
),
)
@env.task
def square(x: int) -> int:
return x ** foo()
@env.task
def main(n: int) -> list[int]:
return list(flyte.map(square, range(n)))
if __name__ == "__main__":
# copy_contents_only=True requires root_dir=parent, False requires root_dir=parent.parent
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
run = flyte.with_runcontext(copy_style="none", version="x").run(main, n=10)
print(run.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/full_build/main.py*
### Local dependency example
The main.py file imports from a local dependency that gets included in the build:
```python
def foo() -> int:
return 1
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/full_build/dep.py*
### Critical configuration components
1. **Set `copy_style` to `"none"`**:
```python
flyte.with_runcontext(copy_style="none", version="x").run(main, n=10)
```
This disables Flyte's fast deployment system and forces a full container build.
2. **Set a custom version**:
```python
flyte.with_runcontext(copy_style="none", version="x").run(main, n=10)
```
The `version` parameter should be set to a desired value (not auto-generated) for consistent image tagging.
3. **Configure image source copying**:
```python
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=True
)
```
Use `.with_source_folder()` to specify what code to copy into the container.
4. **Set `root_dir` correctly**:
```python
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
```
- If `copy_contents_only=True`: Set `root_dir` to the source folder (contents are copied)
- If `copy_contents_only=False`: Set `root_dir` to parent directory (folder is copied)
### Configuration options
#### Option A: Copy Folder Structure
```python
# Copies the entire folder structure into the container
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=False # Default
)
# When copy_contents_only=False, set root_dir to parent.parent
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
#### Option B: Copy Contents Only (Recommended)
```python
# Copies only the contents of the folder (flattens structure)
# This is useful when you want to avoid nested folders - for example all your code is in the root of the repo
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=True
)
# When copy_contents_only=True, set root_dir to parent
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
```
### Version management best practices
When using `copy_style="none"`, always specify an explicit version:
- Use semantic versioning: `"v1.0.0"`, `"v1.1.0"`
- Use build numbers: `"build-123"`
- Use git commits: `"abc123"`
Avoid auto-generated versions to ensure reproducible deployments.
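A small helper for the semantic-versioning convention above (a hypothetical utility, not part of the Flyte SDK):

```python
def bump_patch(version: str) -> str:
    # Increment the patch component of a "vMAJOR.MINOR.PATCH" tag,
    # e.g. to force an image rebuild on the next deployment.
    major, minor, patch = version.removeprefix("v").split(".")
    return f"v{major}.{minor}.{int(patch) + 1}"
```

For example, `bump_patch("v1.0.0")` returns `"v1.0.1"`, which can then be passed as the `version` argument to `with_runcontext`.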
### Performance considerations
- **Full builds take longer** than fast deployment
- **Container images will be larger** as they include all source code
- **Better for production** where immutability is important
- **Use during development** when testing the full deployment pipeline
### When to use
✅ **Use full build when:**
- Deploying to production environments
- Need immutable, reproducible container images
- Working with complex dependency structures
- Deploying to air-gapped or restricted environments
- Building CI/CD pipelines
❌ **Don't use full build when:**
- Rapid development and iteration
- Working with frequently changing code
- Development environments where speed matters
- Simple workflows without complex dependencies
### Troubleshooting
**Common issues:**
1. **Import errors**: Check your `root_dir` configuration matches `copy_contents_only`
2. **Missing files**: Ensure all dependencies are in the source folder
3. **Version conflicts**: Use explicit, unique version strings
4. **Build failures**: Check that the base image has all required system dependencies
**Debug tips:**
- Add print statements to verify file paths in containers
- Use `docker run -it <image> /bin/bash` to inspect built images
- Check Flyte logs for build errors and warnings
- Verify that relative imports work correctly in the container context
## Python path deployment
For projects where workflows are separated from business logic across multiple directories, use the Python path pattern with proper `root_dir` configuration.
### Example structure
```
pythonpath/
├── workflows/
│   └── workflow.py        # Flyte workflow definitions
├── src/
│   └── my_module.py       # Business logic modules
├── run.sh                 # Execute from project root
└── run_inside_folder.sh   # Execute from workflows/ directory
```
### Implementation
```python
import pathlib
from src.my_module import env, say_hello
import flyte
env = flyte.TaskEnvironment(
name="workflow_env",
depends_on=[env],
)
@env.task
async def greet(name: str) -> str:
return await say_hello(name)
if __name__ == "__main__":
current_dir = pathlib.Path(__file__).parent
flyte.init_from_config(root_dir=current_dir.parent)
r = flyte.run(greet, name="World")
print(r.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pythonpath/workflows/workflow.py*
```python
import flyte
env = flyte.TaskEnvironment(
name="my_module",
)
@env.task
async def say_hello(name: str) -> str:
return f"Hello, {name}!"
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pythonpath/src/my_module.py*
### Task environment dependencies
Note how the workflow imports both the task environment and the task function:
```python
from src.my_module import env, say_hello
env = flyte.TaskEnvironment(
name="workflow_env",
depends_on=[env], # Depends on the imported environment
)
```
This pattern allows sharing task environments across modules while maintaining proper dependency relationships.
### Key considerations
- **Import resolution**: `root_dir` enables proper module imports across directories
- **File packaging**: Flyte packages all files starting from `root_dir`
- **Execution flexibility**: Works regardless of where you execute the script
- **PYTHONPATH handling**: Different behavior for CLI vs direct Python execution
### CLI vs Direct Python execution
#### Using Flyte CLI with `--root-dir` (Recommended)
When using `flyte run` with `--root-dir`, you don't need to export PYTHONPATH:
```bash
flyte run --root-dir . workflows/workflow.py greet --name "World"
```
The CLI automatically:
- Adds the `--root-dir` location to `sys.path`
- Resolves all imports correctly
- Packages files from the root directory for remote execution
#### Using Python directly
When running Python scripts directly, you must set PYTHONPATH manually:
```bash
PYTHONPATH=.:$PYTHONPATH python workflows/workflow.py
```
This is because:
- Python doesn't automatically know about your project structure
- You need to explicitly tell Python where to find your modules
- The `root_dir` parameter handles remote packaging, not local path resolution
### Best practices
1. **Always set `root_dir`** when workflows import from multiple directories
2. **Use pathlib** for cross-platform path handling
3. **Set `root_dir` to your project root** to ensure all dependencies are captured
4. **Test both execution patterns** to ensure deployment works from any directory
### Common pitfalls
- **Forgetting `root_dir`**: Results in import errors during remote execution
- **Wrong `root_dir` path**: May package too many or too few files
- **Not setting PYTHONPATH when using Python directly**: Use `flyte run --root-dir .` instead
- **Mixing execution methods**: If you use `flyte run --root-dir .`, you don't need PYTHONPATH
### When to use
- Legacy projects with established directory structures
- Separation of concerns between workflows and business logic
- Multiple workflow definitions sharing common modules
- Projects with complex import hierarchies
**Note:** This pattern is an escape hatch for larger projects where code organization requires separating workflows from business logic. Ideally, structure projects with `pyproject.toml` for cleaner dependency management.
## Dynamic environment deployment
For environments that need to change based on deployment context (development vs production), use dynamic environment selection based on Flyte domains.
### Domain-based environment selection
Use `flyte.current_domain()` to deterministically create different task environments based on the deployment domain:
```python
# NOTE: flyte.init() invocation at the module level is strictly discouraged.
# At runtime, Flyte controls initialization and configuration files are not present.
import os

import flyte

def create_env():
    if flyte.current_domain() == "development":
        return flyte.TaskEnvironment(name="dev", image=flyte.Image.from_debian_base(), env_vars={"MY_ENV": "dev"})
    return flyte.TaskEnvironment(name="prod", image=flyte.Image.from_debian_base(), env_vars={"MY_ENV": "prod"})

env = create_env()

@env.task
async def my_task(n: int) -> int:
    print(f"Environment Variable MY_ENV = {os.environ['MY_ENV']}", flush=True)
    return n + 1

@env.task
async def entrypoint(n: int) -> int:
    print(f"Environment Variable MY_ENV = {os.environ['MY_ENV']}", flush=True)
    return await my_task(n)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dynamic_environments/environment_picker.py*
### Why this pattern works
**Environment reproducibility in local and remote clusters is critical.** Flyte re-instantiates modules in remote clusters, so `current_domain()` will be set correctly based on where the code executes.
✅ **Do use `flyte.current_domain()`** - Flyte automatically sets this based on the execution context
❌ **Don't use environment variables directly** - They won't yield correct results unless manually passed to the downstream system
### How it works
1. Flyte sets the domain context when initializing
2. `current_domain()` returns the domain string (e.g., "development", "staging", "production")
3. Your code deterministically configures resources based on this domain
4. When Flyte executes remotely, it re-instantiates modules with the correct domain context
5. The same environment configuration logic runs consistently everywhere
### Important constraints
`flyte.current_domain()` only works **after** `flyte.init()` is called:
- ✅ Works with `flyte run` and `flyte deploy` CLI commands (they init automatically)
- ✅ Works when called from `if __name__ == "__main__"` after explicit `flyte.init()`
- ❌ Does NOT work at module level without initialization
**Critical:** `flyte.init()` invocation at the module level is **strictly discouraged**: at runtime, Flyte controls initialization, and configuration files are not present in the container.
### Alternative: Environment variable approach
For cases where you need to pass domain information as environment variables to the container runtime, use this approach:
```python
import os

import flyte

def create_env(domain: str):
    # Pass domain as environment variable so tasks can see which domain they're running in
    if domain == "development":
        return flyte.TaskEnvironment(name="dev", image=flyte.Image.from_debian_base(), env_vars={"DOMAIN_NAME": domain})
    return flyte.TaskEnvironment(name="prod", image=flyte.Image.from_debian_base(), env_vars={"DOMAIN_NAME": domain})

env = create_env(os.getenv("DOMAIN_NAME", "development"))

@env.task
async def my_task(n: int) -> int:
    print(f"Environment Variable DOMAIN_NAME = {os.environ['DOMAIN_NAME']}", flush=True)
    return n + 1

@env.task
async def entrypoint(n: int) -> int:
    print(f"Environment Variable DOMAIN_NAME = {os.environ['DOMAIN_NAME']}", flush=True)
    return await my_task(n)

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(entrypoint, n=5)
    print(r.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dynamic_environments_with_envvars/environment_picker.py*
#### Key differences from domain-based approach
- **Environment variable access**: The domain name is available inside tasks via `os.environ['DOMAIN_NAME']`
- **External control**: Can be controlled via system environment variables before execution
- **Runtime visibility**: Tasks can inspect which environment they're running in during execution
- **Default fallback**: Uses `"development"` as default when `DOMAIN_NAME` is not set
#### Usage with environment variables
Set the environment variable and run:
```bash
export DOMAIN_NAME=production
flyte run environment_picker.py entrypoint --n 5
```
Or set it inline:
```bash
DOMAIN_NAME=development flyte run environment_picker.py entrypoint --n 5
```
#### When to use environment variables vs domain-based
**Use environment variables when:**
- Tasks need runtime access to environment information
- External systems set environment configuration
- You need flexibility to override environment externally
- Debugging requires visibility into environment selection
**Use domain-based approach when:**
- Environment selection should be automatic based on Flyte domain
- You want tighter integration with Flyte's domain system
- No need for runtime environment inspection within tasks
You can vary multiple aspects based on context:
- **Base images**: Different images for dev vs prod
- **Environment variables**: Configuration per environment
- **Resource requirements**: Different CPU/memory per domain
- **Dependencies**: Different package versions
- **Registry settings**: Different container registries
### Usage patterns
```bash
flyte run environment_picker.py entrypoint --n 5
flyte deploy environment_picker.py
```
For programmatic usage, ensure proper initialization:
```python
import flyte

flyte.init_from_config()

from environment_picker import entrypoint

if __name__ == "__main__":
    r = flyte.run(entrypoint, n=5)
    print(r.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dynamic_environments/main.py*
### When to use dynamic environments
**General use cases:**
- Multi-environment deployments (dev/staging/prod)
- Different resource requirements per environment
- Environment-specific dependencies or settings
- Context-sensitive configuration needs
**Domain-based approach for:**
- Automatic environment selection tied to Flyte domains
- Simpler configuration without external environment variables
- Integration with Flyte's built-in domain system
**Environment variable approach for:**
- Runtime visibility into environment selection within tasks
- External control over environment configuration
- Debugging and logging environment-specific behavior
- Integration with external deployment systems that set environment variables
## Best practices
### Project organization
1. **Separate concerns**: Keep business logic separate from Flyte task definitions
2. **Use proper imports**: Structure projects for clean import patterns
3. **Version control**: Include all necessary files in version control
4. **Documentation**: Document deployment requirements and patterns
### Image management
1. **Registry configuration**: Use consistent registry settings across environments
2. **Image tagging**: Use meaningful tags for production deployments
3. **Base image selection**: Choose appropriate base images for your needs
4. **Dependency management**: Keep container images lightweight but complete
### Configuration management
1. **Root directory**: Set `root_dir` appropriately for your project structure
2. **Path handling**: Use `pathlib.Path` for cross-platform compatibility
3. **Environment variables**: Use environment-specific configurations
4. **Secrets management**: Handle sensitive data appropriately
### Development workflow
1. **Local testing**: Test tasks locally before deployment
2. **Incremental development**: Use `flyte run` for quick iterations
3. **Production deployment**: Use `flyte deploy` for permanent deployments
4. **Monitoring**: Monitor deployed tasks and environments
## Choosing the right pattern
| Pattern | Use Case | Complexity | Best For |
|---------|----------|------------|----------|
| Simple file | Quick prototypes, learning | Low | Single tasks, experiments |
| Custom Dockerfile | System dependencies, custom environments | Medium | Complex dependencies |
| PyProject package | Professional projects, async pipelines | Medium-High | Production applications |
| Package structure | Multiple workflows, shared utilities | Medium | Organized team projects |
| Full build | Production, reproducibility | High | Immutable deployments |
| Python path | Legacy structures, separated concerns | Medium | Existing codebases |
| Dynamic environment | Multi-environment, domain-aware deployments | Medium | Context-aware deployments |
Start with simpler patterns and evolve to more complex ones as your requirements grow. Many projects will combine multiple patterns as they scale and mature.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/run-context ===
# Run context
Every Flyte run has a **run context** β a set of invocation-time parameters that control where the run executes, where its outputs are stored, how caching behaves, and more.
There are two sides to run context:
- **Write side**: `flyte.with_runcontext()` β set run parameters before the run starts (programmatic) or via CLI flags.
- **Read side**: `flyte.ctx()` β access run parameters inside a running task.
## Configuring a run with `flyte.with_runcontext()`
`flyte.with_runcontext()` returns a runner object. Call `.run(task, ...)` on it to start the run with the specified context:
```python
import flyte

env = flyte.TaskEnvironment("run-context-example")

@env.task
async def process(n: int) -> int:
    return n * 2

@env.task
async def root() -> int:
    return await process(21)

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        name="my-run",
        project="my-project",
        domain="development",
    ).run(root)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/run-context/run_context.py*
All parameters are optional. Unset parameters inherit from the configuration file (`config.yaml`) or system defaults.
### Execution target
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `mode` | `"local"` \| `"remote"` \| `"hybrid"` | *from config* | Where the run executes. `"remote"` runs on the Flyte backend; `"local"` runs in-process. |
| `project` | `str` | *from config* | Project to run in. |
| `domain` | `str` | *from config* | Domain to run in (e.g. `"development"`, `"production"`). |
| `name` | `str` | *auto-generated* | Custom name for the run, visible in the UI. |
| `version` | `str` | *from code bundle* | Version string for the ephemeral task deployment. |
| `queue` | `str` | *from config* | Cluster queue to schedule tasks on. |
| `interruptible` | `bool` | *per-task setting* | Override the interruptible setting for all tasks in the run. `True` allows spot/preemptible instances; `False` forces non-interruptible instances. |
### Storage
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `raw_data_path` | `str` | *from config* | Storage prefix for offloaded data types ([Files](../task-programming/files-and-directories), [Dirs](../task-programming/files-and-directories), [DataFrames](../task-programming/dataframes), checkpoints). Accepts `s3://`, `gs://`, or local paths. |
| `run_base_dir` | `str` | *auto-generated* | Base directory for run metadata passed between tasks. Distinct from `raw_data_path`. |
To direct all task outputs to a specific bucket for a run:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        # Store all task outputs in a dedicated S3 prefix for this run
        raw_data_path="s3://my-bucket/runs/experiment-42/",
    ).run(root)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/run-context/run_context.py*
The equivalent CLI flag is `--raw-data-path`. See [Run command options](./run-command-options#--raw-data-path) for CLI usage.
### Caching
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `overwrite_cache` | `bool` | `False` | Re-execute all tasks even if a cached result exists, and overwrite the cache with new results. |
| `disable_run_cache` | `bool` | `False` | Skip cache lookups and writes entirely for this run. |
| `cache_lookup_scope` | `"global"` \| ... | `"global"` | Scope for cache lookups. |
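The caching parameters compose with the same runner pattern shown above. A minimal sketch (the task and environment names are illustrative, not from the source):

```python
import flyte

env = flyte.TaskEnvironment("cache-demo")

@env.task
async def root() -> int:
    return 42

if __name__ == "__main__":
    flyte.init_from_config()
    # Re-execute every task in this run and refresh the cache with new results.
    # Alternatively, disable_run_cache=True skips cache lookups and writes entirely.
    flyte.with_runcontext(overwrite_cache=True).run(root)
```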
### Identity and resources
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `service_account` | `str` | *from config* | Kubernetes service account for task pods. |
| `env_vars` | `Dict[str, str]` | `None` | Additional environment variables to inject into task containers. |
| `labels` | `Dict[str, str]` | `None` | Kubernetes labels to apply to task pods. |
| `annotations` | `Dict[str, str]` | `None` | Kubernetes annotations to apply to task pods. |
### Logging
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `log_level` | `int` | *from config* | Python log level (e.g. `logging.DEBUG`). |
| `log_format` | `"console"` \| ... | `"console"` | Log output format. |
| `reset_root_logger` | `bool` | `False` | If `True`, preserve the root logger unchanged. |
### Code bundling
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `copy_style` | `"loaded_modules"` \| `"all"` \| `"none"` | `"loaded_modules"` | Code bundling strategy. See [Run command options](./run-command-options#--copy-style). |
| `dry_run` | `bool` | `False` | Build and upload the code bundle without executing the run. |
| `copy_bundle_to` | `Path` | `None` | When `dry_run=True`, copy the bundle to this local path. |
| `interactive_mode` | `bool` | *auto-detected* | Override interactive mode detection (set automatically for Jupyter notebooks). |
| `preserve_original_types` | `bool` | `False` | Keep native DataFrame types (e.g. `pd.DataFrame`) rather than converting to `flyte.io.DataFrame` when deserializing outputs. |
### Context propagation
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `custom_context` | `Dict[str, str]` | `None` | Metadata propagated through the entire task hierarchy. Readable inside any task via `flyte.ctx().custom_context`. See [Custom context](../task-programming/custom-context). |
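The write and read sides of `custom_context` fit together as follows; a minimal sketch, with illustrative task and key names:

```python
import flyte

env = flyte.TaskEnvironment("ctx-demo")

@env.task
async def show_experiment() -> str:
    # Read the metadata that was propagated from with_runcontext()
    return flyte.ctx().custom_context.get("experiment", "unset")

if __name__ == "__main__":
    flyte.init_from_config()
    # The mapping is visible to this task and every task it invokes
    flyte.with_runcontext(custom_context={"experiment": "exp-42"}).run(show_experiment)
```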
---
## Reading context inside a task with `flyte.ctx()`
Inside a running task, `flyte.ctx()` returns a `TaskContext` object with information about the current execution. Outside of a task, it returns `None`.
```python
@env.task
async def inspect_context() -> str:
    ctx = flyte.ctx()
    action = ctx.action
    return (
        f"run={action.run_name}, "
        f"action={action.name}, "
        f"mode={ctx.mode}, "
        f"in_cluster={ctx.is_in_cluster()}"
    )
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/run-context/run_context.py*
### `TaskContext` fields
| Field | Type | Description |
|-------|------|-------------|
| `action` | `ActionID` | Identity of this specific action (task invocation) within the run. |
| `mode` | `"local"` \| `"remote"` \| `"hybrid"` | Execution mode of the current run. |
| `version` | `str` | Version of the deployed task code bundle. |
| `raw_data_path` | `str` | Storage prefix where offloaded outputs are written. |
| `run_base_dir` | `str` | Base directory for run metadata. |
| `custom_context` | `Dict[str, str]` | Propagated context metadata from `with_runcontext()`. |
| `disable_run_cache` | `bool` | Whether run caching is disabled for this run. |
| `is_in_cluster()` | method | Returns `True` when `mode == "remote"`. Useful for branching local/remote behavior. |
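A common use of these fields is branching between local and remote behavior; a sketch under assumed names (the task and paths are illustrative):

```python
import flyte

env = flyte.TaskEnvironment("branch-demo")

@env.task
async def load_data() -> str:
    ctx = flyte.ctx()
    if ctx.is_in_cluster():
        # Remote execution: work against the run's configured storage prefix
        return f"reading from {ctx.raw_data_path}"
    # Local execution: fall back to a local fixture
    return "reading from ./fixtures/sample.csv"
```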
### `ActionID` fields
The `ctx.action` object identifies this specific task invocation:
| Field | Type | Description |
|-------|------|-------------|
| `name` | `str` | Unique identifier for this action. |
| `run_name` | `str` | Name of the parent run (defaults to `name` if not set). |
| `project` | `str \| None` | Project the action runs in. |
| `domain` | `str \| None` | Domain the action runs in. |
| `org` | `str \| None` | Organization. |
### Naming external resources
`ctx.action.run_name` is useful for tying external tool runs (experiment trackers, dashboards) to the corresponding Flyte run:
```python
import wandb  # type: ignore[import]

@env.task
async def train_model(epochs: int) -> float:
    ctx = flyte.ctx()
    # Use run_name to tie the W&B run to this Flyte run
    run = wandb.init(
        project="my-project",
        name=ctx.action.run_name,
        config={"epochs": epochs},
    )
    # ... training logic ...
    loss = 0.42
    run.log({"loss": loss})
    run.finish()
    return loss
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/run-context/run_context.py*
This ensures that when you look up a run in Weights & Biases (or any other tool), its name matches what you see in the Flyte UI.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling ===
# Scale your runs
> [!NOTE]
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
This guide helps you understand and optimize the performance of your Flyte workflows. Whether you're building latency-sensitive applications or high-throughput data pipelines, these docs will help you make the right architectural choices.
## Understanding Flyte execution
Before optimizing performance, it's important to understand how Flyte executes your workflows:
- **Data flow**: Learn how data moves between tasks, including inline vs. reference data types, caching mechanisms, and storage configuration.
- **Life of a run**: Understand what happens when you invoke `flyte.run()`, from code analysis and image building to task execution and state management.
## Performance optimization
Once you understand the fundamentals, dive into performance tuning:
- **Scale your workflows**: A comprehensive guide to optimizing workflow performance, covering latency vs. throughput, task overhead analysis, batching strategies, reusable containers, and more.
## Key concepts for scaling
When scaling your workflows, keep these principles in mind:
1. **Task overhead matters**: The overhead of creating a task (uploading data, enqueuing, creating containers) should be much smaller than the task runtime.
2. **Batch for throughput**: For large-scale data processing, batch multiple items into single tasks to reduce overhead.
3. **Reusable containers**: Eliminate container startup overhead and enable concurrent execution with reusable containers.
4. **Traces for lightweight ops**: Use traces instead of tasks for lightweight operations that need checkpointing.
5. **Limit fanout**: Keep the total number of actions per run below 50k (target 10k-20k for best performance).
6. **Choose the right data types**: Use reference types (files, directories, DataFrames) for large data and inline types for small data.
For detailed guidance on each of these topics, see **Scale your workflows**.
## Subpages
- **Data flow**
- **Life of a run**
- **Scale your workflows**
- **Maximize GPU utilization for batch inference**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling/data-flow ===
# Data flow
Understanding how data flows between tasks is critical for optimizing workflow performance in Flyte. Tasks take inputs and produce outputs, with data flowing seamlessly through your workflow using an efficient transport layer.
## Overview
Flyte tasks are run to completion. Each task takes inputs and produces exactly one output. Even if multiple instances run concurrently (such as in retries), only one output will be accepted. This deterministic data flow model provides several key benefits:
1. **Reduced boilerplate**: Automatic handling of files, DataFrames, directories, custom types, data classes, Pydantic models, and primitive types without manual serialization.
2. **Type safety**: Optional type annotations enable deeper type understanding, automatic UI form generation, and runtime type validation.
3. **Efficient transport**: Data is passed by reference (files, directories, DataFrames) or by value (primitives) based on type.
4. **Durable storage**: All data is stored durably and accessible through APIs and the UI.
5. **Caching support**: Efficient caching using shallow immutable references for referenced data.
## Data types and transport
Flyte handles different data types with different transport mechanisms:
### Passed by reference
These types are not copied but passed as references to storage locations:
- **Files**: `flyte.io.File`
- **Directories**: `flyte.io.Dir`
- **Dataframes**: `flyte.io.DataFrame`, `pd.DataFrame`, `pl.DataFrame`, etc.
Dataframes are automatically converted to Parquet format and read using Apache Arrow for zero-copy reads. Use `flyte.io.DataFrame` for lazy materialization to any supported type like pandas or polars. [Learn more about the Flyte Dataframe type](../../user-guide/task-programming/dataframes)
### Passed by value (inline I/O)
Primitive and structured types are serialized and passed inline:
| Type Category | Examples | Serialization |
|--------------|----------|---------------|
| **Primitives** | `int`, `float`, `str`, `bool`, `None` | MessagePack |
| **Time types** | `datetime.datetime`, `datetime.date`, `datetime.timedelta` | MessagePack |
| **Collections** | `list`, `dict`, `tuple` | MessagePack |
| **Data structures** | data classes, Pydantic `BaseModel` | MessagePack |
| **Enums** | `enum.Enum` subclasses | MessagePack |
| **Unions** | `Union[T1, T2]`, `Optional[T]` | MessagePack |
| **Protobuf** | `google.protobuf.Message` | Binary |
Flyte uses efficient MessagePack serialization for most types, providing compact binary representation with strong type safety.
> [!NOTE]
> If type annotations are not used, or if `typing.Any` or unrecognized types are used, data will be pickled. By default, pickled objects smaller than 10KB are passed inline, while larger pickled objects are automatically passed as a file. Pickling allows for progressive typing but should be used carefully.
## Task execution and data flow
### Input download
When a task starts:
1. **Inline inputs download**: The task downloads inline inputs from the configured Flyte object store.
2. **Size limits**: By default, inline inputs are limited to 10MB, but this can be adjusted using `flyte.TaskEnvironment`'s `max_inline_io` parameter.
3. **Memory consideration**: Inline data is materialized in memory, so adjust your task resources accordingly.
4. **Reference materialization**: Reference data (files, directories) is passed using special types in `flyte.io`. Dataframes are automatically materialized if using `pd.DataFrame`. Use `flyte.io.DataFrame` to avoid automatic materialization.
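The inline size limit from step 2 can be raised per environment; a minimal sketch, assuming `max_inline_io` takes a byte count (the name and value here are illustrative):

```python
import flyte

# Raise the inline I/O ceiling for tasks that exchange larger inline payloads.
# Inline data is materialized in memory, so size the task's memory accordingly.
env = flyte.TaskEnvironment(
    name="big-inline-io",
    max_inline_io=50 * 1024 * 1024,  # 50MB instead of the 10MB default
)
```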
### Output upload
When a task returns data:
1. **Inline data**: Uploaded to the Flyte object store configured at the organization, project, or domain level.
2. **Reference data**: Stored in the same metadata store by default, or configured using `flyte.with_runcontext(raw_data_path=...)`.
3. **Separate prefixes**: Each task creates one output per retry attempt in separate prefixes, so attempts cannot corrupt each other's data.
## Task-to-task data flow
When a task invokes downstream tasks:
1. **Input recording**: The input to the downstream task is recorded to the object store.
2. **Reference upload**: All referenced objects are uploaded (if not already present).
3. **Task invocation**: The downstream task is invoked on the remote server.
4. **Parallel execution**: When multiple tasks are invoked in parallel using `flyte.map` or `asyncio`, inputs are written in parallel.
5. **Storage layer**: Data writing uses the `flyte.storage` layer, backed by the Rust-based `object-store` crate and optionally `fsspec` plugins.
6. **Output download**: Once the downstream task completes, inline outputs are downloaded and returned to the calling task.
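Step 4's parallel fan-out looks like ordinary `asyncio` code inside a task; a minimal sketch with illustrative names:

```python
import asyncio

import flyte

env = flyte.TaskEnvironment("fanout-demo")

@env.task
async def square(n: int) -> int:
    return n * n

@env.task
async def fan_out(count: int) -> list[int]:
    # Child task invocations are awaited concurrently; Flyte records and
    # uploads their inputs in parallel before dispatching each one.
    return list(await asyncio.gather(*[square(i) for i in range(count)]))
```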
## Caching and data hashing
Understanding how Flyte caches data is essential for performance optimization.
### Cache key computation
A cache hit occurs when the following components match:
- **Task name**: The fully-qualified task name
- **Computed input hash**: Hash of all inputs (excluding `ignored_inputs`)
- **Task interface hash**: Hash of input and output types
- **Task config hash**: Hash of task configuration
- **Cache version**: User-specified or automatically computed
### Inline data caching
All inline data is cached using a consistent hashing system. The cache key is derived from the data content.
### Reference data hashing
Reference data (files, directories) is hashed shallowly by default using the hash of the storage location. You can customize hashing:
- Use `flyte.io.File.new_remote()` or `flyte.io.File.from_existing_remote()` with custom hash functions or values.
- Provide explicit hash values for deep content hashing if needed.
### Cache control
Control caching behavior using `flyte.with_runcontext`:
- **Scope**: Set `cache_lookup_scope` to `"global"` or `"project/domain"`.
- **Disable cache**: Set `overwrite_cache=True` to force re-execution.
For more details on caching configuration, see [Caching](../task-configuration/caching).
## Traces and data flow
When using [traces](../task-programming/traces), the data flow behavior is different:
1. **Full execution first**: The trace is fully executed before inputs and outputs are recorded.
2. **Checkpoint behavior**: Recording happens like a checkpoint at the end of trace execution.
3. **Streaming iterators**: The entire output is buffered and recorded after the stream completes. Buffering is pass-through, allowing caller functions to consume output while buffering.
4. **Chained traces**: All traces are recorded after the last one completes consumption.
5. **Same process with `asyncio`**: Traces run within the same Python process and support `asyncio` parallelism, so failures can be retried, effectively re-running the trace.
6. **Lightweight overhead**: Traces only have the overhead of data storage (no task orchestration overhead).
> [!NOTE]
> Traces are not a substitute for tasks if you need caching. Tasks provide full caching capabilities, while traces provide lightweight checkpointing with storage overhead. However, traces support concurrent execution using `asyncio` patterns within a single task.
## Object stores and latency considerations
By default, Flyte uses object stores like S3, GCS, Azure Storage, and R2 as metadata stores. These have high latency for smaller objects, so:
- **Minimum task duration**: Tasks should take at least a second to run to amortize storage overhead.
- **Future improvements**: High-performance metastores like Redis and PostgreSQL may be supported in the future. Contact the Union team if you're interested.
## Configuring data storage
### Organization and project level
Object stores are configured at the organization level or per project/domain. Documentation for this configuration is coming soon.
### Per-run configuration
Configure raw data storage on a per-run basis using `flyte.with_runcontext`:
```python
run = flyte.with_runcontext(
    raw_data_path="s3://my-bucket/custom-path"
).run(my_task, input_data=data)
```
This allows you to control where reference data (files, directories, DataFrames) is stored for specific runs.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling/life-of-a-run ===
# Life of a run
Understanding what happens when you invoke `flyte.run()` is crucial for optimizing workflow performance and debugging issues. This guide walks through each phase of task execution from submission to completion.
## Overview
When you execute `flyte.run()`, the system goes through several phases:
1. **Code analysis and preparation**: Discover environments and images
2. **Image building**: Build container images if changes are detected
3. **Code bundling**: Package your Python code
4. **Upload**: Transfer the code bundle to object storage
5. **Run creation**: Submit the run to the backend
6. **Task execution**: Execute the task in the data plane
7. **State management**: Track and persist execution state
## Phase 1: Code analysis and preparation
When `flyte.run()` is invoked:
1. **Environment discovery**: Flyte analyzes your code and finds all relevant `flyte.TaskEnvironment` instances by walking the `depends_on` hierarchy.
2. **Image identification**: Discovers unique `flyte.Image` instances used across all environments.
3. **Image building**: Starts the image building process. Images are only built if a change is detected.
> [!NOTE]
> If you invoke `flyte.run()` multiple times within the same Python process without changing code (such as in a notebook or script), the code bundling and image building steps are done only once. This can dramatically speed up iteration.
## Phase 2: Image building
Container images provide the runtime environment for your tasks:
- **Change detection**: Images are only rebuilt if changes are detected in dependencies or configuration.
- **Caching**: Previously built images are reused when possible.
- **Parallel builds**: Multiple images can be built concurrently.
For more details on container images, see [Container Images](../task-configuration/container-images).
## Phase 3: Code bundling
After images are built, your project files are bundled:
### Default: `copy_style="loaded_modules"`
By default, all Python modules referenced by the invoked tasks through module-level import statements are automatically copied. This provides a good balance between completeness and efficiency.
### Alternative: `copy_style="none"`
Skip bundling by setting `copy_style="none"` in `flyte.with_runcontext()` and adding all code into `flyte.Image`:
```python
import flyte

# Add code to image
image = flyte.Image().with_source_code("/path/to/code")

# Or use Dockerfile
image = flyte.Image.from_dockerfile("Dockerfile")

# Skip bundling
run = flyte.with_runcontext(copy_style="none").run(my_task, input_data=data)
```
For more details on code packaging, see [Packaging](../task-deployment/packaging).
## Phase 4: Upload code bundle
Once the code bundle is created:
1. **Request signed URL**: The SDK sends the bundle checksum and target path to the Control Plane.
2. **Control Plane obtains URL**: The Control Plane calls the Data Plane to obtain a signed URL for that checksum and path.
3. **Direct upload**: The signed URL is returned to the SDK, which uploads the code bundle directly to the object store.
## Phase 5: Run creation and queuing
The `CreateRun` API is invoked:
1. **Copy inputs**: Input data is copied to the object store.
2. **Enqueue the run**: The run is queued into the Union Control Plane.
3. **Hand off to executor**: Union Control Plane hands the task to the Executor Service in your data plane.
4. **Create action**: The parent task action (called `a0`) is created.
## Phase 6: Task execution in data plane
### Container startup
1. **Container starts**: The task container starts in your data plane.
2. **Download code bundle**: The Flyte runtime downloads the code bundle from object storage.
3. **Inflate task**: The task is inflated from the code bundle.
4. **Download inputs**: Inline inputs are downloaded from the object store.
5. **Execute task**: The task is executed with context and inputs.
### Invoking downstream tasks
If the task invokes other tasks:
1. **Controller thread**: A controller thread starts to communicate with the backend Queue Service.
2. **Monitor status**: The controller monitors the status of downstream actions.
3. **Crash recovery**: If the task crashes, the action identifier is deterministic, allowing the task to resurrect its state from Union Control Plane.
4. **Replay**: The controller efficiently replays state (even at large scale) to find missing completions and resume monitoring.
### Execution flow diagram
```mermaid
sequenceDiagram
participant Client as SDK/Client
participant Control as Control Plane (Queue Service)
participant Data as Data Plane (Executor)
participant ObjStore as Object Store
participant Container as Task Container
Client->>Client: Analyze code & discover environments
Client->>Client: Build images (if changed)
Client->>Client: Bundle code
Client->>Control: Request signed URL (checksum, path)
Control->>Data: Get signed URL for bundle
Data-->>Control: Signed URL
Control-->>Client: Signed URL
Client->>ObjStore: Upload code bundle (signed URL)
Client->>Control: CreateRun API with inputs
Control->>Data: Copy inputs
Data->>ObjStore: Write inputs
Control->>Data: Queue task (create action a0)
Data->>Container: Start container
Container->>Data: Request code bundle
Data->>ObjStore: Read code bundle
ObjStore-->>Data: Code bundle
Data-->>Container: Code bundle
Container->>Container: Inflate task
Container->>Data: Request inputs
Data->>ObjStore: Read inputs
ObjStore-->>Data: Inputs
Data-->>Container: Inputs
Container->>Container: Execute task
alt Invokes downstream tasks
Container->>Container: Start controller thread
Container->>Control: Submit downstream tasks
Control->>Data: Queue downstream actions
Container->>Control: Monitor downstream status
Control-->>Container: Status updates
end
Container->>Data: Upload outputs
Data->>ObjStore: Write outputs
Container->>Control: Complete
Control-->>Client: Run complete
```
## Action identifiers and crash recovery
Flyte uses deterministic action identifiers to enable robust crash recovery:
- **Consistent identifiers**: Action identifiers are consistently computed based on task and invocation context.
- **Re-run identical**: In any re-run, the action identifier is identical for the same invocation.
- **Multiple invocations**: Multiple invocations of the same task receive unique identifiers.
- **Efficient resurrection**: On crash, the `a0` action resurrects its state from Union Control Plane efficiently, even at large scale.
- **Replay and resume**: The controller replays execution until it finds missing completions and starts watching them.
## Downstream task execution
When downstream tasks are invoked:
1. **Action creation**: Downstream actions are created with unique identifiers.
2. **Queue assignment**: Actions are handed to an executor, selected either via a specific queue or from the general pool.
3. **Parallel execution**: Multiple downstream tasks can execute in parallel.
4. **Result aggregation**: Results are aggregated and returned to the parent task.
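The fan-out and aggregation pattern can be sketched with plain `asyncio` (no Flyte APIs here; `child` and `parent` are hypothetical stand-ins for task invocations):

```python
import asyncio


async def child(i: int) -> int:
    """Stand-in for a downstream task; in Flyte this would be its own action."""
    await asyncio.sleep(0)  # stands in for remote execution
    return i * i


async def parent(n: int) -> int:
    """Stand-in for the parent task that fans out and aggregates."""
    # Fan out: in Flyte, each invocation gets its own unique action identifier.
    results = await asyncio.gather(*[child(i) for i in range(n)])
    # Aggregate: results come back to the parent in invocation order.
    return sum(results)


total = asyncio.run(parent(4))  # 0 + 1 + 4 + 9 = 14
```

In a real workflow the coroutine calls would be `@env.task` invocations, but the control flow, concurrent fan-out followed by ordered aggregation, is the same.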
## Reusable containers
When using [reusable containers](../task-configuration/reusable-containers), the execution model changes:
1. **Environment spin-up**: The container environment is first spun up with configured replicas.
2. **Task allocation**: Tasks are allocated to available replicas in the environment.
3. **Scaling**: If all replicas are busy, new replicas are spun up (up to the configured maximum), or tasks are backlogged in queues.
4. **Container reuse**: The same container handles multiple task executions, reducing startup overhead.
5. **Lifecycle management**: Containers are managed according to `ReusePolicy` settings (`idle_ttl`, `scaledown_ttl`, etc.).
### Reusable container execution flow
```mermaid
sequenceDiagram
participant Control as Queue Service
participant Executor as Executor Service
participant Pool as Container Pool
participant Replica as Container Replica
Control->>Executor: Submit task
alt Reusable containers enabled
Executor->>Pool: Request available replica
alt Replica available
Pool->>Replica: Allocate task
Replica->>Replica: Execute task
Replica->>Pool: Task complete (ready for next)
else No replica available
alt Can scale up
Executor->>Pool: Create new replica
Pool->>Replica: Spin up new container
Replica->>Replica: Execute task
Replica->>Pool: Task complete
else At max replicas
Executor->>Pool: Queue task
Pool-->>Executor: Wait for available replica
Pool->>Replica: Allocate when available
Replica->>Replica: Execute task
Replica->>Pool: Task complete
end
end
else No reusable containers
Executor->>Replica: Create new container
Replica->>Replica: Execute task
Replica->>Executor: Complete & terminate
end
Replica-->>Control: Return results
```
## State replication and visualization
### Queue Service to Run Service
1. **Reliable replication**: Queue Service reliably replicates execution state back to Run Service.
2. **Eventual consistency**: The Run Service may be slightly behind the actual execution state.
3. **Visualization**: Run Service paints the entire run onto the UI.
### UI limitations
- **Current limit**: The UI is currently limited to displaying 50k actions per run.
- **Future improvements**: This limit will be increased in future releases. Contact the Union team if you need higher limits.
## Optimization opportunities
Understanding the life of a run reveals several optimization opportunities:
1. **Reuse Python process**: Run `flyte.run()` multiple times in the same process to avoid re-bundling code.
2. **Skip bundling**: Use `copy_style="none"` and bake code into images for faster startup.
3. **Reusable containers**: Use reusable containers to eliminate container startup overhead.
4. **Parallel execution**: Invoke multiple downstream tasks concurrently using `flyte.map()` or `asyncio`.
5. **Efficient data flow**: Minimize data transfer by using reference types (files, directories) instead of inline data.
6. **Caching**: Enable task caching to avoid redundant computation.
For detailed performance tuning guidance, see [Scale your workflows](./scale-your-workflows).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling/scale-your-workflows ===
# Scale your workflows
Performance optimization in Flyte involves understanding the interplay between task execution overhead, data transfer, and concurrency. This guide helps you identify bottlenecks and choose the right patterns for your workload.
## Understanding performance dimensions
Performance optimization focuses on two key dimensions:
### Latency
**Goal**: Minimize end-to-end execution time for individual workflows.
**Characteristics**:
- Fast individual actions (milliseconds to seconds)
- Total action count typically less than 1,000
- Critical for interactive applications and real-time processing
- Multi-step inference that reuses a model or data in memory (use reusable containers with [`@alru_cache`](https://pypi.org/project/async-lru/))
**Recommended approach**:
- Use tasks for orchestration and parallelism
- Use [traces](../task-programming/traces) for fine-grained checkpointing
- Model parallelism using `asyncio`, joining parallel work with methods like `asyncio.as_completed` or `asyncio.gather`
- Leverage [reusable containers](../task-configuration/reusable-containers) with concurrency to eliminate startup overhead and optimize resource utilization
### Throughput
**Goal**: Maximize the number of items processed per unit time.
**Characteristics**:
- Processing large datasets (millions of items)
- High total action count (10k to 50k actions)
- Batch processing, large-scale batch inference and ETL workflows
**Recommended approach**:
- Batch workloads to reduce overhead
- Limit fanout to manage system load
- Use reusable containers with concurrency for maximum utilization
- Balance task granularity with overhead
## Task execution overhead
Understanding task overhead is critical for performance optimization. When you invoke a task, several operations occur:
| Operation | Symbol | Description |
|-----------|--------|-------------|
| **Upload data** | `u` | Time to upload input data to object store |
| **Download data** | `d` | Time to download input data from object store |
| **Enqueue task** | `e` | Time to enqueue task in Queue Service |
| **Create instance** | `t` | Time to create task container instance |
**Total overhead per task**: `2u + 2d + e + t`
This overhead includes:
- Uploading inputs from the parent task (`u`)
- Downloading inputs in the child task (`d`)
- Uploading outputs from the child task (`u`)
- Downloading outputs in the parent task (`d`)
- Enqueuing the task (`e`)
- Creating the container instance (`t`)
### The overhead principle
For efficient execution, task overhead should be much smaller than task runtime:
```
Total overhead (2u + 2d + e + t) << Task runtime
```
If task runtime is comparable to or less than overhead, consider:
1. **Batching**: Combine multiple work items into a single task
2. **Traces**: Use traces instead of tasks for lightweight operations
3. **Reusable containers**: Eliminate container creation overhead (`t`)
4. **Local execution**: Run lightweight operations within the parent task
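With illustrative timings, the principle can be checked numerically (substitute measurements from your own deployment):

```python
# Back-of-the-envelope check of the overhead principle (all values in seconds;
# the numbers are illustrative, not measured defaults).
u, d, e, t = 0.5, 0.5, 0.2, 5.0  # upload, download, enqueue, container create

overhead = 2 * u + 2 * d + e + t  # 7.2 s per task invocation

short_task = 2.0    # runtime comparable to overhead: batch it or use traces
long_task = 120.0   # runtime >> overhead: a task is the right granularity


def worth_a_task(runtime_s: float, factor: float = 10.0) -> bool:
    """Rule of thumb from this guide: runtime should exceed ~10x overhead."""
    return runtime_s > factor * overhead
```

Here `worth_a_task(short_task)` is false and `worth_a_task(long_task)` is true: the 2-second operation should be batched or traced, while the 2-minute one is worth a dedicated task.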
## System architecture and data flow
To optimize performance, understand how tasks flow through the system:
1. **Control plane to data plane**: Tasks flow from the control plane (Run Service, Queue Service) to the data plane (Executor Service).
2. **Data movement**: Data moves between tasks through object storage. See [Data flow](./data-flow) for details.
3. **State replication**: Queue Service reliably replicates state back to Run Service for visualization. The Run Service may be slightly behind actual execution.
For a detailed walkthrough of task execution, see [Life of a run](./life-of-a-run).
## Optimization strategies
### 1. Use reusable containers for concurrency
[Reusable containers](../task-configuration/reusable-containers) eliminate the container creation overhead (`t`) and enable concurrent task execution:
```python
import flyte
from datetime import timedelta
# Define reusable environment
env = flyte.TaskEnvironment(
name="high-throughput",
reuse_policy=flyte.ReusePolicy(
replicas=(2, 10), # Auto-scale from 2 to 10 replicas
concurrency=5, # 5 tasks per replica = 50 max concurrent
scaledown_ttl=timedelta(minutes=10),
idle_ttl=timedelta(hours=1)
)
)
@env.task
async def process_item(item: dict) -> dict:
# Process individual item
return {"processed": item["id"]}
```
**Benefits**:
- Eliminates container startup overhead (`t → 0`)
- Supports concurrent execution (multiple tasks per container)
- Auto-scales based on demand
- Reuses Python environment and loaded dependencies
**Limitations**:
- Concurrency is limited by CPU and I/O resources in the container
- Memory requirements scale with total working set size
- Best for I/O-bound tasks or async operations
### 2. Batch workloads to reduce overhead
For high-throughput processing, batch multiple items into a single task:
```python
import asyncio

@env.task
async def process_batch(items: list[dict]) -> list[dict]:
"""Process a batch of items in a single task."""
results = []
for item in items:
result = await process_single_item(item)
results.append(result)
return results
@env.task
async def process_large_dataset(dataset: list[dict]) -> list[dict]:
"""Process 1M items with batching."""
batch_size = 1000 # Adjust based on overhead calculation
batches = [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
# Process batches in parallel (1000 tasks instead of 1M)
results = await asyncio.gather(*[process_batch(batch) for batch in batches])
# Flatten results
return [item for batch_result in results for item in batch_result]
```
**Benefits**:
- Reduces total number of tasks (e.g., 1000 tasks instead of 1M)
- Amortizes overhead across multiple items
- Lower load on Queue Service and object storage
**Choosing batch size**:
1. Calculate overhead: `overhead = 2u + 2d + e + t`
2. Target task runtime: `runtime > 10 × overhead` (rule of thumb)
3. Adjust batch size to achieve target runtime
4. Consider memory constraints (larger batches require more memory)
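Those steps can be turned into arithmetic (numbers are illustrative; measure your own overhead):

```python
# Solving the steps above for a batch size, working in milliseconds to keep
# the arithmetic exact (values are illustrative).
overhead_ms = 7_000                      # measured 2u + 2d + e + t
per_item_ms = 100                        # measured time to process one item
target_runtime_ms = 10 * overhead_ms     # rule of thumb: runtime > 10x overhead

batch_size = target_runtime_ms // per_item_ms  # 700 items per batch

n_items = 1_000_000
n_tasks = -(-n_items // batch_size)  # ceiling division: ~1,429 tasks, well under 50k
```

If the resulting batch would exceed the memory available to a single task, cap `batch_size` at what fits and accept a larger task count instead.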
### 3. Use traces for lightweight operations
[Traces](../task-programming/traces) provide fine-grained checkpointing with minimal overhead:
```python
@flyte.trace
async def fetch_data(url: str) -> dict:
"""Traced function for API call."""
response = await http_client.get(url)
return response.json()
@flyte.trace
async def transform_data(data: dict) -> dict:
"""Traced function for transformation."""
return {"transformed": data}
@env.task
async def process_workflow(urls: list[str]) -> list[dict]:
"""Orchestrate using traces instead of tasks."""
results = []
for url in urls:
data = await fetch_data(url)
transformed = await transform_data(data)
results.append(transformed)
return results
```
**Benefits**:
- Only storage overhead (no task orchestration overhead)
- Runs in the same Python process with asyncio parallelism
- Provides checkpointing and resumption
- Visible in execution logs and UI
**Trade-offs**:
- No caching (use tasks for cacheable operations)
- Shares resources with the parent task (CPU, memory)
- Storage writes may still be slow due to object store latency
**When to use traces**:
- API calls and external service interactions
- Deterministic transformations that need checkpointing
- Operations taking more than 1 second (to amortize storage overhead)
### 4. Limit fanout for system stability
The UI and system have limits on the number of actions per run:
- **Current limit**: 50k actions per run
- **Future**: Higher limits will be supported (contact the Union team if needed)
**Example: Control fanout with batching**
```python
@env.task
async def process_million_items(items: list[dict]) -> list[dict]:
"""Process 1M items with controlled fanout."""
# Target ~10k tasks, each processing 100 items
batch_size = 100  # 1M items / 100 per batch = 10k tasks, within the 50k limit
batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
# Use flyte.map for parallel execution
results = await flyte.map(process_batch, batches)
return [item for batch in results for item in batch]
```
### 5. Optimize data transfer
Minimize data transfer overhead by choosing appropriate data types:
**Use reference types for large data**:
```python
from flyte.io import File, Directory, DataFrame
@env.task
async def process_large_file(input_file: File) -> File:
"""Files passed by reference, not copied."""
# Download only when needed
local_path = input_file.download()
# Process file
result_path = process(local_path)
# Upload result
return File.new_remote(result_path)
```
**Use inline types for small data**:
```python
@env.task
async def process_metadata(metadata: dict) -> dict:
"""Small dicts passed inline efficiently."""
return {"processed": metadata}
```
**Guideline**:
- **< 10 MB**: Use inline types (primitives, small dicts, lists)
- **> 10 MB**: Use reference types (File, Directory, DataFrame)
- **Adjust**: Use `max_inline_io` in `TaskEnvironment` to change the threshold
See [Data flow](./data-flow) for details on data types and transport.
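The decision rule above can be written down directly (`choose_transport` is a hypothetical helper for illustration, not a Flyte API):

```python
# The ~10 MB default threshold from the guideline above; in Flyte the actual
# cutoff is tunable via max_inline_io on the TaskEnvironment.
INLINE_LIMIT = 10 * 1024 * 1024


def choose_transport(payload_bytes: int) -> str:
    """Small payloads travel inline; large ones go by reference (File, Directory)."""
    return "inline" if payload_bytes < INLINE_LIMIT else "reference"
```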
### 6. Leverage caching
Enable [caching](../task-configuration/caching) to avoid redundant computation:
```python
@env.task(cache="auto")
async def expensive_computation(input_data: dict) -> dict:
"""Automatically cached based on inputs."""
# Expensive operation
return result
```
**Benefits**:
- Skips re-execution for identical inputs
- Reduces overall workflow runtime
- Preserves resources for new computations
**When to use**:
- Deterministic tasks (same inputs → same outputs)
- Expensive computations (model training, large data processing)
- Stable intermediate results
### 7. Parallelize with `flyte.map`
Use [`flyte.map`](../task-programming/fanout) for data-parallel workloads:
```python
@env.task
async def process_item(item: dict) -> dict:
return {"processed": item}
@env.task
async def parallel_processing(items: list[dict]) -> list[dict]:
"""Process items in parallel using map."""
results = await flyte.map(process_item, items)
return results
```
**Benefits**:
- Automatic parallelization
- Dynamic scaling based on available resources
- Built-in error handling and retries
**Best practices**:
- Combine with batching to control fanout
- Use with reusable containers for maximum throughput
- Consider memory and resource limits
## Performance tuning workflow
Follow this workflow to optimize your Flyte workflows:
1. **Profile**: Measure task execution times and identify bottlenecks.
2. **Calculate overhead**: Estimate `2u + 2d + e + t` for your tasks.
3. **Compare**: Check if `task runtime >> overhead`. If not, optimize.
4. **Batch**: Increase batch size to amortize overhead.
5. **Reusable containers**: Enable reusable containers to eliminate `t`.
6. **Traces**: Use traces for lightweight operations within tasks.
7. **Cache**: Enable caching for deterministic, expensive tasks.
8. **Limit fanout**: Keep total actions below 50k (target 10k-20k).
9. **Monitor**: Use the UI to monitor execution and identify issues.
10. **Iterate**: Continuously refine based on performance metrics.
## Real-world example: PyIceberg batch processing
For a comprehensive example of efficient data processing with Flyte, see the [PyIceberg parallel batch aggregation example](https://github.com/flyteorg/flyte-sdk/blob/main/examples/data_processing/pyiceberg_example.py). This example demonstrates:
- **Zero-copy data passing**: Pass file paths instead of data between tasks
- **Reusable containers with concurrency**: Maximize CPU utilization across workers
- **Parallel file processing**: Use `asyncio.gather()` to process multiple files concurrently
- **Efficient batching**: Distribute parquet files across worker tasks
Key pattern from the example:
```python
# Instead of loading entire table, get file paths
file_paths = [task.file.file_path for task in table.scan().plan_files()]
# Distribute files across partitions (zero-copy!)
partition_files = distribute_files(file_paths, num_partitions)
# Process partitions in parallel
results = await asyncio.gather(*[
aggregate_partition(files, partition_id)
for partition_id, files in enumerate(partition_files)
])
```
This approach achieves true parallel file processing without loading the entire dataset into memory.
## Example: Optimizing a data pipeline
### Before optimization
```python
import asyncio

@env.task
async def process_item(item: dict) -> dict:
# Very fast operation (~100ms)
return {"processed": item["id"]}
@env.task
async def process_dataset(items: list[dict]) -> list[dict]:
# Create 1M tasks
results = await asyncio.gather(*[process_item(item) for item in items])
return results
```
**Issues**:
- 1M tasks created (exceeds UI limit)
- Task overhead >> task runtime (100ms task, seconds of overhead)
- High load on Queue Service and object storage
### After optimization
```python
# Use reusable containers
env = flyte.TaskEnvironment(
name="optimized-pipeline",
reuse_policy=flyte.ReusePolicy(
replicas=(5, 20),
concurrency=10,
scaledown_ttl=timedelta(minutes=10),
idle_ttl=timedelta(hours=1)
)
)
@env.task
async def process_batch(items: list[dict]) -> list[dict]:
# Process batch of items
return [{"processed": item["id"]} for item in items]
@env.task
async def process_dataset(items: list[dict]) -> list[dict]:
# Create 1000 tasks (batch size 1000)
batch_size = 1000
batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
results = await flyte.map(process_batch, batches)
return [item for batch in results for item in batch]
```
**Improvements**:
- 1000 tasks instead of 1M (within limits)
- Batch runtime ~100 seconds (100ms × 1000 items)
- Reusable containers eliminate startup overhead
- Concurrency enables high throughput (200 concurrent tasks max)
## When to contact the Union team
Reach out to the Union team if you:
- Need more than 50k actions per run
- Want to use high-performance metastores (Redis, PostgreSQL) instead of object stores
- Have specific performance requirements or constraints
- Need help profiling and optimizing your workflows
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling/batch-inference ===
# Maximize GPU utilization for batch inference
GPUs are expensive. When running batch inference, the single biggest cost driver is **idle GPU time**: cycles where the GPU sits waiting with nothing to do. Understanding why this happens and how to fix it is the key to cost-effective batch inference.
## Why GPU utilization drops
A typical inference task does three things:
1. **Load data**: read from storage, deserialize, preprocess (CPU/IO-bound)
2. **Run inference**: forward pass through the model (GPU-bound)
3. **Post-process**: format results, write outputs (CPU/IO-bound)
When these steps run sequentially, the GPU is idle during steps 1 and 3. For many workloads, data loading and preprocessing dominate wall-clock time, leaving the GPU busy for only a fraction of the total:
```mermaid
gantt
title Sequential execution leaves the GPU idle during CPU/IO work
dateFormat X
axisFormat %s
section Task 1
Load data (CPU/IO) :a1, 0, 3
Inference (GPU) :a2, after a1, 2
Post-process (CPU/IO) :a3, after a2, 1
section Task 2
Load data (CPU/IO) :b1, after a3, 3
Inference (GPU) :b2, after b1, 2
Post-process (CPU/IO) :b3, after b2, 1
section GPU
Idle :crit, g1, 0, 3
Busy :active, g2, 3, 5
Idle :crit, g3, 5, 9
Busy :active, g4, 9, 11
Idle :crit, g5, 11, 12
```
In this example, the GPU is busy for only 4 of 12 time units: **33% utilization**. The rest is wasted waiting for CPU and IO operations.
## Serving vs in-process batch inference
There are two common approaches to batch inference: sending requests to a **hosted model server** (serving), or running the model **in-process** alongside data loading. Each has distinct trade-offs:
| | Hosted serving | In-process (Flyte) |
|---|---|---|
| **Architecture** | Separate inference server (e.g. Triton, vLLM server, TGI) accessed over the network | Model loaded directly in the task process, inference via `DynamicBatcher` |
| **Data transfer** | Every request serialized over the network; large payloads add latency | Zero-copy: data stays in-process, no serialization overhead |
| **Backpressure** | Hard to implement; push-based architecture can overwhelm the server or drop requests | Two levels: the `DynamicBatcher` queue blocks producers when full, and Flyte's task scheduling automatically queues new inference tasks when replicas are busy, so backpressure propagates end-to-end without any extra code |
| **Utilization** | Servers are often over-provisioned to maintain availability, leading to low average utilization | Batcher continuously fills the GPU with work from concurrent producers |
| **Multi-model** | Each model needs its own serving deployment, load balancer, and scaling config | Multiple models can time-share the same GPU: when one model finishes, the next is loaded automatically via reusable containers, no container orchestration required |
| **Scaling** | Requires separate infrastructure for the serving layer (load balancers, autoscalers, health checks) | Scales with Flyte β replicas auto-scale based on demand |
| **Cost** | Pay for always-on serving infrastructure even during low-traffic periods | Pay only for the duration of the batch job |
| **Fault tolerance** | Need retries, circuit breakers, and timeout handling for network failures | Failures are local; Flyte handles retries and recovery at the task level |
| **Best for** | Real-time / low-latency serving with unpredictable request patterns | Large-scale batch processing with known datasets |
For batch workloads, in-process inference eliminates the network overhead and infrastructure complexity of a serving layer while achieving higher GPU utilization through intelligent batching.
## Solution: `DynamicBatcher`
`DynamicBatcher` from `flyte.extras` solves the utilization problem by **separating data loading from inference** and running them concurrently. Multiple async producers load and preprocess data while a single consumer feeds the GPU in optimally-sized batches:
```mermaid
flowchart LR
subgraph producers ["Concurrent producers (CPU/IO)"]
P1["Stream 1: load + preprocess"]
P2["Stream 2: load + preprocess"]
P3["Stream N: load + preprocess"]
end
subgraph batcher ["DynamicBatcher"]
Q["Queue with backpressure"]
A["Aggregation loop (assembles cost-budgeted batches)"]
Q --> A
end
subgraph consumer ["Processing loop (GPU)"]
G["process_fn / inference_fn (batched forward pass)"]
end
P1 --> Q
P2 --> Q
P3 --> Q
A --> G
```
The batcher runs two internal loops:
1. **Aggregation loop**: drains the submission queue and assembles batches that respect a cost budget (`target_batch_cost`), a maximum size (`max_batch_size`), and a timeout (`batch_timeout_s`). This ensures the GPU always receives optimally-sized batches.
2. **Processing loop**: pulls assembled batches and calls your processing function, resolving each record's future with its result.
This pipelining means the GPU is processing batch N while data for batch N+1 is being loaded and assembled, **eliminating idle time**.
### Basic usage
```python
from flyte.extras import DynamicBatcher
async def process(batch: list[dict]) -> list[str]:
"""Your batch processing function. Must return results in the same order as the input."""
return [heavy_computation(item) for item in batch]
async with DynamicBatcher(
process_fn=process,
target_batch_cost=1000, # cost budget per batch
max_batch_size=64, # hard cap on records per batch
batch_timeout_s=0.05, # max wait time before dispatching a partial batch
max_queue_size=5_000, # queue size for backpressure
) as batcher:
futures = []
for record in my_records:
future = await batcher.submit(record, estimated_cost=10)
futures.append(future)
results = await asyncio.gather(*futures)
```
Each call to `submit()` enqueues the record and immediately returns a `Future`. When the queue is full, `submit()` awaits until space is available, providing natural backpressure that prevents producers from overwhelming the GPU.
### Cost estimation
The batcher uses cost estimates to decide how many records to group into each batch. You can provide costs in several ways (checked in order of precedence):
1. **Explicit**: pass `estimated_cost` to `submit()`
2. **Estimator function**: pass `cost_estimator` to the constructor
3. **Protocol**: implement `estimate_cost()` on your record type
4. **Default**: fall back to `default_cost` (default: 1)
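A runnable sketch of that precedence order (illustration only; `resolve_cost` is a hypothetical function, not the `DynamicBatcher` internals):

```python
from typing import Any, Callable, Optional


def resolve_cost(
    record: Any,
    estimated_cost: Optional[float] = None,                   # 1. explicit, via submit()
    cost_estimator: Optional[Callable[[Any], float]] = None,  # 2. constructor function
    default_cost: float = 1.0,                                # 4. fallback
) -> float:
    """Mirror the documented precedence: explicit > estimator > protocol > default."""
    if estimated_cost is not None:
        return estimated_cost
    if cost_estimator is not None:
        return cost_estimator(record)
    if hasattr(record, "estimate_cost"):                      # 3. CostEstimator protocol
        return record.estimate_cost()
    return default_cost


class Doc:
    def estimate_cost(self) -> float:
        return 42.0


# resolve_cost("x", estimated_cost=5) -> 5 (explicit wins)
# resolve_cost("abc", cost_estimator=len) -> 3 (estimator function)
# resolve_cost(Doc()) -> 42.0 (protocol method)
# resolve_cost("x") -> 1.0 (default)
```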
## `TokenBatcher` for LLM inference
For LLM workloads, `TokenBatcher` is a convenience subclass that uses token-aware parameter names:
```python
from dataclasses import dataclass
from flyte.extras import TokenBatcher
@dataclass
class Prompt:
text: str
def estimate_tokens(self) -> int:
"""Rough token estimate (~4 chars per token)."""
return len(self.text) // 4 + 1
async def inference(batch: list[Prompt]) -> list[str]:
"""Run batched inference through your model."""
texts = [p.text for p in batch]
outputs = model.generate(texts, sampling_params)
return [o.outputs[0].text for o in outputs]
async with TokenBatcher(
inference_fn=inference,
target_batch_tokens=32_000, # token budget per batch
max_batch_size=256,
) as batcher:
future = await batcher.submit(Prompt(text="What is 2+2?"))
result = await future
```
`TokenBatcher` checks the `TokenEstimator` protocol (`estimate_tokens()`) in addition to `CostEstimator` (`estimate_cost()`), making it natural to work with prompt types.
## Monitoring utilization
`DynamicBatcher` exposes a `stats` property with real-time metrics:
```python
stats = batcher.stats
print(f"Utilization: {stats.utilization:.1%}") # fraction of time spent processing
print(f"Records processed: {stats.total_completed}")
print(f"Batches dispatched: {stats.total_batches}")
print(f"Avg batch size: {stats.avg_batch_size:.1f}")
print(f"Busy time: {stats.busy_time_s:.1f}s")
print(f"Idle time: {stats.idle_time_s:.1f}s")
```
| Metric | Description |
|---|---|
| `utilization` | Fraction of wall-clock time spent inside `process_fn` (0.0 to 1.0). Target: > 0.9. |
| `total_submitted` | Total records submitted via `submit()` |
| `total_completed` | Total records whose futures have been resolved |
| `total_batches` | Number of batches dispatched to `process_fn` |
| `avg_batch_size` | Running average records per batch |
| `avg_batch_cost` | Running average cost per batch |
| `busy_time_s` | Cumulative seconds spent inside `process_fn` |
| `idle_time_s` | Cumulative seconds the processing loop waited for batches |
If utilization is low, consider:
- **Increasing concurrency**: more concurrent producers means the batcher has more records to assemble into batches
- **Reducing `batch_timeout_s`**: dispatch partial batches faster instead of waiting
- **Increasing `max_queue_size`**: allow more records to be buffered ahead of the GPU
- **Adding more data streams**: ensure the GPU always has work queued up
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps ===
# Configure apps
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
`[[AppEnvironment]]`s allow you to configure the environment in which your app runs, including the container image, compute resources, secrets, domains, scaling behavior, and more.
Similar to `[[TaskEnvironment]]`, configuration can be set when creating the `[[AppEnvironment]]` object. Unlike tasks, apps are long-running services, so they have additional configuration options specific to web services:
- `port`: What port the app listens on
- `command` and `args`: How to start the app
- `scaling`: Autoscaling configuration for handling variable load
- `domain`: Custom domains and subdomains for your app
- `requires_auth`: Whether the app requires authentication to access
- `depends_on`: Other app or task environments that the app depends on
## Hello World example
Here's a complete example of deploying a simple Streamlit "hello world" app with a custom subdomain.
There are two ways to build apps in Flyte:
1. Defining `AppEnvironment(.., args=[...])` to run the app with the underlying `fserve` command.
2. Defining `@app_env.server` to run the app with a custom server function.
### Using fserve args
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# ///
import flyte
import flyte.app
# {{docs-fragment image}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1")
# {{/docs-fragment image}}
# {{docs-fragment app-env}}
app_env = flyte.app.AppEnvironment(
name="hello-world-app",
image=image,
args=["streamlit", "hello", "--server.port", "8080"],
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
domain=flyte.app.Domain(subdomain="hello"),
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config()
# Deploy the app
app = flyte.serve(app_env)
print(f"App served at: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/hello-world-app.py*
This example demonstrates:
- Creating a custom Docker image with Streamlit
- Setting `args` to run the Streamlit hello app via the underlying `fserve` command
- Configuring the port
- Setting resource limits
- Disabling authentication (for public access)
- Using a custom subdomain
### Using @app_env.server
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
import flyte
import flyte.app

# {{docs-fragment image}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1")
# {{/docs-fragment image}}

# {{docs-fragment app-env}}
app_env = flyte.app.AppEnvironment(
    name="hello-world-app-server",
    image=image,
    port=8080,
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    requires_auth=False,
    domain=flyte.app.Domain(subdomain="hello-server"),
)

@app_env.server
def server():
    import subprocess

    subprocess.run(["streamlit", "hello", "--server.port", "8080"], check=False)
# {{/docs-fragment app-env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    # Deploy the app
    app = flyte.serve(app_env)
    print(f"App served at: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/hello-world-app-server.py*
This example demonstrates:
- Creating a custom Docker image with Streamlit
- Using the `@app_env.server` decorator to define a server function that runs the Streamlit hello app.
- Configuring the port
- Setting resource limits
- Disabling authentication (for public access)
- Using a custom subdomain
Once deployed, your app will be accessible at the generated URL or your custom subdomain.
## Differences from TaskEnvironment
While `AppEnvironment` inherits from `Environment` (the same base class as `TaskEnvironment`), it has several app-specific parameters:
| Parameter | AppEnvironment | TaskEnvironment | Description |
|-----------|----------------|-----------------|-------------|
| `type` | ✅ | ❌ | Type of app (e.g., "FastAPI", "Streamlit") |
| `port` | ✅ | ❌ | Port the app listens on |
| `args` | ✅ | ❌ | Arguments to pass to the app |
| `command` | ✅ | ❌ | Command to run the app |
| `requires_auth` | ✅ | ❌ | Whether app requires authentication |
| `scaling` | ✅ | ❌ | Autoscaling configuration |
| `domain` | ✅ | ❌ | Custom domain/subdomain |
| `links` | ✅ | ❌ | Links to include in the App UI page |
| `include` | ✅ | ❌ | Files to include in app |
| `parameters` | ✅ | ❌ | Parameters to pass to app |
| `cluster_pool` | ✅ | ❌ | Cluster pool for deployment |
Parameters like `image`, `resources`, `secrets`, `env_vars`, and `depends_on` are shared between both environment types. See the [task configuration](../task-configuration/_index) docs for details on these shared parameters.
## Configuration topics
Learn more about configuring apps:
- **Configure apps > App environment settings**: Images, resources, secrets, and app-specific settings like `type`, `port`, `args`, `requires_auth`
- **Configure apps > App environment settings > App startup**: Understanding the difference between `args` and `command`
- **Configure apps > Including additional files**: How to include additional files needed by your app
- **Configure apps > Passing parameters into app environments**: Pass parameters to your app at deployment time
- **Configure apps > App environment settings > `scaling`**: Configure scaling up and down based on traffic with idle TTL
- **Configure apps > Apps depending on other environments**: Use `depends_on` to deploy dependent apps together
## Subpages
- **Configure apps > App environment settings**
- **Configure apps > Including additional files**
- **Configure apps > Passing parameters into app environments**
- **Configure apps > Auto-scaling apps**
- **Configure apps > Apps depending on other environments**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/app-environment-settings ===
# App environment settings
`[[AppEnvironment]]`s control how your apps run in Flyte, including images, resources, secrets, startup behavior, and autoscaling.
## Shared environment settings
`[[AppEnvironment]]`s share many configuration options with `[[TaskEnvironment]]`s:
- **Images**: See [Container images](../task-configuration/container-images/) for details on creating and using container images
- **Resources**: See [Resources](../task-configuration/resources/) for CPU, memory, GPU, and storage configuration
- **Secrets**: See [Secrets](../task-configuration/secrets/) for injecting secrets into your app
- **Environment variables**: Set via the `env_vars` parameter (same as tasks)
- **Cluster pools**: Specify via the `cluster_pool` parameter
## App-specific environment settings
For complete parameter documentation and type signatures, see the [`AppEnvironment` API reference](../../api-reference/flyte-sdk/packages/flyte.app/appenvironment).
### `type`
The `type` parameter is an optional string that identifies what kind of app this is. It's used for organizational purposes and may be used by the UI or tooling to display or filter apps.
```python
app_env = flyte.app.AppEnvironment(
    name="my-fastapi-app",
    type="FastAPI",
    # ...
)
```
When using specialized app environments like `FastAPIAppEnvironment`, the type is automatically set. For custom apps, you can set it to any string value.
### `port`
The `port` parameter specifies which port your app listens on. It can be an integer or a `Port` object.
```python
# Using an integer (simple case)
app_env = flyte.app.AppEnvironment(name="my-app", port=8080, ...)
# Using a Port object (more control)
app_env = flyte.app.AppEnvironment(
    name="my-app",
    port=flyte.app.Port(port=8080),
    # ...
)
```
The default port is `8080`. Your app should listen on this port (or the port you specify).
> [!NOTE]
> Ports 8012, 8022, 8112, 9090, and 9091 are reserved and cannot be used for apps.
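If you want to fail fast before deploying, a small pre-flight check against this reserved set can help. This is a sketch; `validate_port` is a hypothetical helper, not part of the Flyte SDK:

```python
# Ports reserved by the platform (from the note above); apps cannot bind these.
RESERVED_PORTS = {8012, 8022, 8112, 9090, 9091}

def validate_port(port: int) -> int:
    """Raise early if the chosen app port collides with a reserved port."""
    if port in RESERVED_PORTS:
        raise ValueError(f"Port {port} is reserved and cannot be used for apps")
    return port

validate_port(8080)  # OK: 8080 is the default app port
```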
### `args`
The `args` parameter specifies arguments to pass to your app's command. This is typically used when you need to pass additional arguments to the command specified in `command`, or when using the default command behavior.
```python
app_env = flyte.app.AppEnvironment(
    name="streamlit-app",
    args="streamlit run main.py --server.port 8080",
    port=8080,
    # ...
)
```
`args` can be either a string (which will be shell-split) or a list of strings:
```python
# String form (will be shell-split)
args="--option1 value1 --option2 value2"
# List form (more explicit)
args=["--option1", "value1", "--option2", "value2"]
```
#### Environment variable substitution
Environment variables are automatically substituted in `args` strings when they start with the `$` character. This works for both:
- Values from `env_vars`
- Secrets that are specified as environment variables (via `as_env_var` in `flyte.Secret`)
The `$VARIABLE_NAME` syntax will be replaced with the actual environment variable value at runtime:
```python
# Using env_vars
app_env = flyte.app.AppEnvironment(
    name="my-app",
    env_vars={"API_KEY": "secret-key-123"},
    args="--api-key $API_KEY",  # $API_KEY will be replaced with "secret-key-123"
    # ...
)

# Using secrets
app_env = flyte.app.AppEnvironment(
    name="my-app",
    secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
    args=["--api-key", "$AUTH_SECRET"],  # $AUTH_SECRET will be replaced with the secret value
    # ...
)
```
This is particularly useful for passing API keys or other sensitive values to command-line arguments without hardcoding them in your code. The substitution happens at runtime, ensuring secrets are never exposed in your code or configuration files.
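The substitution can be pictured as a simple token replacement over the argument list. This is a simplified model for illustration only, not the actual Flyte implementation:

```python
def substitute_args(args: list[str], env: dict[str, str]) -> list[str]:
    """Replace $NAME tokens in each argument with values from an env mapping.

    A simplified model of the runtime substitution described above; the
    real implementation in Flyte may differ in edge cases.
    """
    out = []
    for arg in args:
        for name, value in env.items():
            arg = arg.replace(f"${name}", value)
        out.append(arg)
    return out

print(substitute_args(["--api-key", "$API_KEY"], {"API_KEY": "secret-key-123"}))
# → ['--api-key', 'secret-key-123']
```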
> [!TIP]
> For most `AppEnvironment`s, use `args` instead of `command` to specify the app startup command
> in the container. This is because `args` will use the `fserve` command to run the app, which
> unlocks features like local code bundling and file/directory mounting via parameter injection.
### `command`
The `command` parameter specifies the full command to run your app. If not specified, Flyte uses a default command that runs your app via `fserve`, the executable provided by the `flyte` package for running apps.
```python
# Explicit command
app_env = flyte.app.AppEnvironment(
    name="streamlit-hello",
    command="streamlit hello --server.port 8080",
    port=8080,
    # ...
)
# Using default command (recommended for most cases)
# When command is None, Flyte generates a command based on your app configuration
app_env = flyte.app.AppEnvironment(name="my-app", ...) # command=None by default
```
> [!TIP]
> For most apps, especially when using specialized app environments like `FastAPIAppEnvironment`, you don't need to specify `command` as it's automatically configured. Use `command` when you need
> to specify the raw container command, e.g. when running a non-Python app or when you have all
> of the dependencies and data used by the app available in the container.
### `requires_auth`
The `requires_auth` parameter controls whether the app requires authentication to access. By default, apps require authentication (`requires_auth=True`).
```python
# Public app (no authentication required)
app_env = flyte.app.AppEnvironment(
    name="public-dashboard",
    requires_auth=False,
    # ...
)

# Private app (authentication required - default)
app_env = flyte.app.AppEnvironment(
    name="internal-api",
    requires_auth=True,
    # ...
)  # Default
```
When `requires_auth=True`, users must authenticate with Flyte to access the app. When `requires_auth=False`, the app is publicly accessible (though it may still require API keys or other app-level authentication).
### `domain`
The `domain` parameter specifies a custom domain or subdomain for your app. Use `flyte.app.Domain` to configure a subdomain or custom domain.
```python
app_env = flyte.app.AppEnvironment(
    name="my-app",
    domain=flyte.app.Domain(subdomain="myapp"),
    # ...
)
```
### `links`
The `links` parameter adds links to the App UI page. Use `flyte.app.Link` objects to specify relative or absolute links with titles.
```python
app_env = flyte.app.AppEnvironment(
    name="my-app",
    links=[
        flyte.app.Link(path="/docs", title="API Documentation", is_relative=True),
        flyte.app.Link(path="/health", title="Health Check", is_relative=True),
        flyte.app.Link(path="https://www.example.com", title="External link", is_relative=False),
    ],
    # ...
)
```
### `include`
The `include` parameter specifies files and directories to include in the app bundle. Use glob patterns or explicit paths to include code files needed by your app.
```python
app_env = flyte.app.AppEnvironment(
    name="my-app",
    include=["*.py", "models/", "utils/", "requirements.txt"],
    # ...
)
```
> [!NOTE]
> Learn more about including additional files in your app deployment [here](./including-additional-files).
### `parameters`
The `parameters` parameter passes parameters to your app at deployment time. Parameters can be primitive values, files, directories, or delayed values like `RunOutput` or `AppEndpoint`.
```python
app_env = flyte.app.AppEnvironment(
    name="my-app",
    parameters=[
        flyte.app.Parameter(name="config", value="foo", env_var="BAR"),
        flyte.app.Parameter(name="model", value=flyte.io.File(path="s3://bucket/model.pkl"), mount="/mnt/model"),
        flyte.app.Parameter(name="data", value=flyte.io.File(path="s3://bucket/data.pkl"), mount="/mnt/data"),
    ],
    # ...
)
```
> [!NOTE]
> Learn more about passing parameters to your app at deployment time [here](./passing-parameters).
### `scaling`
The `scaling` parameter configures autoscaling behavior for your app. Use `flyte.app.Scaling` to set replica ranges and idle TTL.
```python
app_env = flyte.app.AppEnvironment(
    name="my-app",
    scaling=flyte.app.Scaling(
        replicas=(1, 5),
        scaledown_after=300,  # Scale down after 5 minutes of idle time
    ),
    # ...
)
```
> [!NOTE]
> Learn more about autoscaling apps [here](./auto-scaling-apps).
### `depends_on`
The `depends_on` parameter specifies environment dependencies. When you deploy an app, all dependencies are deployed first.
```python
backend_env = flyte.app.AppEnvironment(name="backend-api", ...)
frontend_env = flyte.app.AppEnvironment(
    name="frontend-app",
    depends_on=[backend_env],  # backend-api will be deployed first
    # ...
)
```
> [!NOTE]
> Learn more about app environment dependencies [here](./apps-depending-on-environments).
## App startup
There are two ways to start up an app in Flyte:
1. With a server function using `@app_env.server`
2. As a container command using `command` or `args`
### Server decorator via `@app_env.server`
The server function is a Python function that runs the app. It is defined using the `@app_env.server` decorator.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "fastapi",
#     "uvicorn",
#     "flyte>=2.0.0b52",
# ]
# ///
import fastapi
import uvicorn

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment fastapi-app}}
app = fastapi.FastAPI()

env = FastAPIAppEnvironment(
    name="configure-fastapi-example",
    app=app,
    image=flyte.Image.from_uv_script(__file__, name="configure-fastapi-example"),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    port=8080,
)

@env.server
def server():
    print("Starting server...")
    uvicorn.run(app, port=8080)

@app.get("/")
async def root() -> dict:
    return {"message": "Hello from FastAPI!"}
# {{/docs-fragment fastapi-app}}

# {{docs-fragment on-startup-decorator}}
state = {}

@env.on_startup
async def app_startup():
    print("App started up")
    state["data"] = ["Here's", "some", "data"]
# {{/docs-fragment on-startup-decorator}}

# {{docs-fragment on-shutdown-decorator}}
@env.on_shutdown
async def app_shutdown():
    print("App shut down")
    state.clear()  # clears the data
# {{/docs-fragment on-shutdown-decorator}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    import logging

    flyte.init_from_config(log_level=logging.DEBUG)
    deployed_app = flyte.serve(env)
    print(f"App served at: {deployed_app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/fastapi-server-example.py*
The `@app_env.server` decorator allows you to define a synchronous or asynchronous function that runs the app, for example with a blocking server call like `uvicorn.run` or [`HTTPServer.serve_forever`](https://docs.python.org/3/library/http.server.html).
> [!NOTE]
> Generally the `[[FastAPIAppEnvironment]]` handles serving automatically under the hood;
> the example above just shows how the `@app_env.server` decorator can be used to define a server function
> that runs the app.
#### Startup hook
The startup hook is called after the app container starts and before the server begins handling requests. It is defined using the `@app_env.on_startup` decorator. This is useful if you need to load state or open external connections that the app needs before it starts serving.
```python
state = {}

@env.on_startup
async def app_startup():
    print("App started up")
    state["data"] = ["Here's", "some", "data"]
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/fastapi-server-example.py*
#### Shutdown hook
The shutdown hook is called before the app instance shuts down during scale-down. It is defined using the
`@app_env.on_shutdown` decorator. This is useful if you need to clean up state or external connections in the
container running the app.
```python
@env.on_shutdown
async def app_shutdown():
    print("App shut down")
    state.clear()  # clears the data
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/fastapi-server-example.py*
### Container command via `command` vs `args`
The difference between `args` and `command` is crucial for properly configuring how your app starts.
- **`command`**: The full command to run your app, for example, `"streamlit hello --server.port 8080"`. For most use
cases, you don't need to specify `command` as it's automatically configured, and uses the `fserve` executable to
run the app. `fserve` does additional setup for you, like setting up the code bundle and loading [parameters](./passing-parameters) if provided, so it's highly recommended to use the default command.
- **`args`**: Arguments to pass to your app's command (used with the default Flyte command or your custom command). The
  `fserve` executable accepts additional arguments, which you specify as the arguments needed to run your app, e.g.
  `streamlit run main.py --server.port 8080`.
#### Default startup behavior
When you don't specify a `command`, Flyte generates a default command that uses `fserve` to run your app. This default command handles:
- Setting up the code bundle
- Configuring the version
- Setting up project/domain context
- Injecting parameters if provided
The default command looks like:
```bash
fserve --version ... --project ... --domain ... --
```
So if you specify `args`, they'll be appended after the `--` separator.
#### Using args with the default command
When you use `args` without specifying `command`, the args are passed to the default Flyte command:
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "fastapi",
#     "flyte>=2.0.0b52",
# ]
# ///
import flyte
import flyte.app

# {{docs-fragment args-with-default-command}}
# Using args with default command
app_env = flyte.app.AppEnvironment(
    name="streamlit-app",
    args="streamlit run main.py --server.port 8080",
    port=8080,
    include=["main.py"],
    # command is None, so default Flyte command is used
)
# {{/docs-fragment args-with-default-command}}

# {{docs-fragment explicit-command}}
# Using explicit command
app_env2 = flyte.app.AppEnvironment(
    name="streamlit-hello",
    command="streamlit hello --server.port 8080",
    port=8080,
    # No args needed since command includes everything
)
# {{/docs-fragment explicit-command}}

# {{docs-fragment command-with-args}}
# Using command with args
app_env3 = flyte.app.AppEnvironment(
    name="custom-app",
    command="python -m myapp",
    args="--option1 value1 --option2 value2",
    # This runs: python -m myapp --option1 value1 --option2 value2
)
# {{/docs-fragment command-with-args}}

# {{docs-fragment fastapi-auto-command}}
# FastAPIAppEnvironment automatically sets command
from flyte.app.extras import FastAPIAppEnvironment
from fastapi import FastAPI

app = FastAPI()

env = FastAPIAppEnvironment(
    name="my-api",
    app=app,
    # You typically don't need to specify command or args, since the
    # FastAPIAppEnvironment automatically uses the bundled code to serve the
    # app via uvicorn.
)
# {{/docs-fragment fastapi-auto-command}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py*
This effectively runs:
```bash
fserve --version ... --project ... --domain ... -- streamlit run main.py --server.port 8080
```
#### Using an explicit command
When you specify a `command`, it completely replaces the default command:
```python
# Using explicit command
app_env2 = flyte.app.AppEnvironment(
    name="streamlit-hello",
    command="streamlit hello --server.port 8080",
    port=8080,
    # No args needed since command includes everything
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py*
This runs exactly:
```bash
streamlit hello --server.port 8080
```
#### Using a command with args
You can combine both, though this is less common:
```python
# Using command with args
app_env3 = flyte.app.AppEnvironment(
    name="custom-app",
    command="python -m myapp",
    args="--option1 value1 --option2 value2",
    # This runs: python -m myapp --option1 value1 --option2 value2
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py*
#### FastAPIAppEnvironment example
When using `FastAPIAppEnvironment`, the command is automatically configured to run uvicorn:
```python
# FastAPIAppEnvironment automatically sets command
from fastapi import FastAPI
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI()

env = FastAPIAppEnvironment(
    name="my-api",
    app=app,
    # You typically don't need to specify command or args, since the
    # FastAPIAppEnvironment automatically uses the bundled code to serve the
    # app via uvicorn.
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py*
The `FastAPIAppEnvironment` automatically:
1. Detects the module and variable name of your FastAPI app
2. Uses an internal server function to start the app via `uvicorn.run`
3. Handles all the startup configuration for you
## Shared settings
For more details on shared settings like images, resources, and secrets, refer to the [task configuration](../task-configuration/_index) documentation.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/including-additional-files ===
# Including additional files
When your app needs additional files beyond the main script (like utility modules, configuration files, or data files), you can use the `include` parameter to specify which files to bundle with your app.
## How include works
The `include` parameter takes a list of file paths (relative to the directory containing your app definition). These files are bundled together and made available in the app container at runtime.
```python
include=["main.py", "utils.py", "config.yaml"]
```
## When to use include
Use `include` when:
- Your app spans multiple Python files (modules)
- You have configuration files that your app needs
- You have data files or templates your app uses
- You want to ensure specific files are available in the container
> [!NOTE]
> If you're using specialized app environments like `FastAPIAppEnvironment`, Flyte automatically detects and includes the necessary files, so you may not need to specify `include` explicitly.
## Examples
### Multi-file Streamlit app
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""A custom Streamlit app with multiple files."""

import pathlib

import flyte
import flyte.app

# {{docs-fragment app-env}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "streamlit==1.41.1",
    "pandas==2.2.3",
    "numpy==2.2.3",
)

app_env = flyte.app.AppEnvironment(
    name="streamlit-multi-file-app",
    image=image,
    args="streamlit run main.py --server.port 8080",
    port=8080,
    include=["main.py", "utils.py"],  # Include your app files
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    requires_auth=False,
)
# {{/docs-fragment app-env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app = flyte.deploy(app_env)
    print(f"Deployed app: {app[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/multi_file_streamlit.py*
In this example:
- `main.py` is your main Streamlit app file
- `utils.py` contains helper functions used by `main.py`
- Both files are included in the app bundle
### Multi-file FastAPI app
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "fastapi",
# ]
# ///
"""Multi-file FastAPI app example."""

import pathlib

from fastapi import FastAPI
from module import function  # Import from another file

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment app-definition}}
app = FastAPI(title="Multi-file FastAPI Demo")

app_env = FastAPIAppEnvironment(
    name="fastapi-multi-file",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    # FastAPIAppEnvironment automatically includes necessary files
    # But you can also specify explicitly:
    # include=["app.py", "module.py"],
)
# {{/docs-fragment app-definition}}

# {{docs-fragment endpoint}}
@app.get("/")
async def root():
    return function()  # Uses function from module.py
# {{/docs-fragment endpoint}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(app_env)
    print(f"Deployed: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/app.py*
### App with configuration files
```python
include=["app.py", "config.yaml", "templates/"]
```
## File discovery
When using specialized app environments like `FastAPIAppEnvironment`, Flyte uses code introspection to automatically discover and include the necessary files. This means you often don't need to manually specify `include`.
However, if you have files that aren't automatically detected (like configuration files, data files, or templates), you should explicitly list them in `include`.
## Path resolution
Files in `include` are resolved relative to the directory containing your app definition file. For example:
```
project/
└── apps/
    ├── app.py       # Your app definition
    ├── utils.py     # Included file
    └── config.yaml  # Included file
```
In `app.py`:
```python
include=["utils.py", "config.yaml"] # Relative to apps/ directory
```
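A minimal sketch of this resolution rule. `resolve_includes` is a hypothetical helper for illustration, not part of the Flyte SDK:

```python
import pathlib

def resolve_includes(app_file: str, include: list[str]) -> list[pathlib.PurePosixPath]:
    """Resolve include entries relative to the directory of the app definition file."""
    app_dir = pathlib.PurePosixPath(app_file).parent
    return [app_dir / entry for entry in include]

paths = resolve_includes("project/apps/app.py", ["utils.py", "config.yaml"])
print([str(p) for p in paths])
# → ['project/apps/utils.py', 'project/apps/config.yaml']
```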
## Best practices
1. **Only include what you need**: Unnecessary files increase bundle size.
2. **Use relative paths**: Always use paths relative to your app definition file.
3. **Include directories**: You can include entire directories, but be mindful of their size.
4. **Test locally**: Verify your includes work by testing locally before deploying.
5. **Check automatic discovery**: Specialized app environments may already include files automatically.
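A bad `include` entry usually only surfaces at deploy time. A small generic check (plain Python, not part of the Flyte SDK) can catch missing paths earlier, using the same relative-path resolution described above:

```python
from pathlib import Path


def check_includes(app_file: str, include: list[str]) -> list[str]:
    """Return the `include` entries that don't exist relative to the app file's directory."""
    base = Path(app_file).resolve().parent
    return [entry for entry in include if not (base / entry).exists()]
```

Running this against your app definition before calling `flyte.deploy` surfaces typos in `include` without waiting for a failed deployment.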
## Limitations
- Large files or directories can slow down deployment.
- Binary files are supported, but consider using external data storage (S3, etc.) for very large files.
- The bundle size is limited by your Flyte cluster configuration.
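Since bundle size is the main constraint, it can help to estimate it before deploying. A rough sketch (generic Python, not a Flyte API):

```python
from pathlib import Path


def bundle_size_bytes(app_dir: str, include: list[str]) -> int:
    """Rough total size, in bytes, of the files an `include` list would bundle."""
    total = 0
    for entry in include:
        path = Path(app_dir) / entry
        # A file contributes itself; a directory contributes everything under it.
        candidates = [path] if path.is_file() else path.rglob("*")
        total += sum(p.stat().st_size for p in candidates if p.is_file())
    return total
```

If the estimate is large, move the heavy assets to object storage and load them at runtime instead of bundling them.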
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/passing-parameters ===
# Passing parameters into app environments
`[[AppEnvironment]]`s support various parameter types that can be passed at deployment time. This includes primitive values, files, directories, and delayed values like `RunOutput` and `AppEndpoint`.
## Parameter types overview
There are several parameter types:
- **Primitive values**: Strings, numbers, booleans
- **Files**: `flyte.io.File` objects
- **Directories**: `flyte.io.Dir` objects
- **Delayed values**: `RunOutput` (injects the output of a task run) or `AppEndpoint` (injects the endpoint URL of another app)
## Basic parameter types
```python
import flyte
import flyte.app
import flyte.io

# String parameters
app_env = flyte.app.AppEnvironment(
    name="configurable-app",
    parameters=[
        flyte.app.Parameter(name="environment", value="production"),
        flyte.app.Parameter(name="log_level", value="INFO"),
    ],
    # ...
)

# File parameters
app_env2 = flyte.app.AppEnvironment(
    name="app-with-model",
    parameters=[
        flyte.app.Parameter(
            name="model_file",
            value=flyte.io.File("s3://bucket/models/model.pkl"),
            mount="/app/models",
        ),
    ],
    # ...
)

# Directory parameters
app_env3 = flyte.app.AppEnvironment(
    name="app-with-data",
    parameters=[
        flyte.app.Parameter(
            name="data_dir",
            value=flyte.io.Dir("s3://bucket/data/"),
            mount="/app/data",
        ),
    ],
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-parameters-examples.py*
## Delayed values
Delayed values are parameters whose actual values are materialized at deployment time.
### RunOutput
Use `RunOutput` to pass outputs from task runs as app parameters:
```python
import flyte
import flyte.app
import flyte.io

# Delayed parameters with RunOutput
env = flyte.TaskEnvironment(name="training-env")

@env.task
async def train_model() -> flyte.io.File:
    # ... training logic ...
    return await flyte.io.File.from_local("/tmp/trained-model.pkl")

# Use the task output as an app parameter
app_env4 = flyte.app.AppEnvironment(
    name="serving-app",
    parameters=[
        flyte.app.Parameter(
            name="model",
            value=flyte.app.RunOutput(type="file", run_name="training_run", task_name="train_model"),
            mount="/app/model",
        ),
    ],
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-parameters-examples.py*
The `type` argument is required and must be one of `string`, `file`, or `directory`.
When the app is deployed, Flyte makes the remote calls needed to resolve the
parameter's actual value.
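This constraint can be expressed as a tiny validation helper (illustrative only; this mirrors the documented rule, not the SDK's actual check):

```python
# Illustrative only: the three documented RunOutput types.
VALID_RUN_OUTPUT_TYPES = {"string", "file", "directory"}


def validate_run_output_type(type_: str) -> str:
    """Reject any `type` value other than string, file, or directory."""
    if type_ not in VALID_RUN_OUT PUT_TYPES if False else VALID_RUN_OUTPUT_TYPES and type_ not in VALID_RUN_OUTPUT_TYPES:
        raise ValueError(
            f"invalid RunOutput type {type_!r}; expected one of {sorted(VALID_RUN_OUTPUT_TYPES)}"
        )
    return type_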
### AppEndpoint
Use `AppEndpoint` to pass endpoints from other apps:
```python
import flyte
import flyte.app

# Delayed parameters with AppEndpoint
app1_env = flyte.app.AppEnvironment(name="backend-api")
app2_env = flyte.app.AppEnvironment(
    name="frontend-app",
    parameters=[
        flyte.app.Parameter(
            name="backend_url",
            value=flyte.app.AppEndpoint(app_name="backend-api"),
            env_var="BACKEND_URL",  # app1_env's endpoint will be available as an environment variable
        ),
    ],
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-parameters-examples.py*
The endpoint URL will be injected as the parameter value when the app starts.
This is particularly useful when you want to chain apps together (for example, a frontend app calling a backend app), without hardcoding URLs.
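Inside the consuming app, the injected endpoint is then read like any other environment variable. A minimal sketch (`BACKEND_URL` matches the `env_var` configured above; the localhost fallback is an assumption for local runs):

```python
import os

# The platform injects the backend's endpoint URL at startup via env_var="BACKEND_URL".
backend_url = os.getenv("BACKEND_URL", "http://localhost:8000")  # fallback for local runs


def backend_route(path: str) -> str:
    """Build a full URL for a route on the backend app."""
    return f"{backend_url.rstrip('/')}/{path.lstrip('/')}"
```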
## Overriding parameters at serve time
You can override parameter values when serving apps (this is not supported for deployment):
```python
# Override parameters when serving
app = flyte.with_servecontext(
    parameter_values={"my-app": {"model_path": "s3://bucket/new-model.pkl"}}
).serve(app_env)
```
> [!NOTE]
> Parameter overrides are only available when using `flyte.serve()` or `flyte.with_servecontext().serve()`.
> The `flyte.deploy()` function does not support parameter overrides - parameters must be specified in the `AppEnvironment` definition.
This is useful for:
- Testing different configurations during development
- Using different models or data sources for testing
- A/B testing different app configurations
## Example: FastAPI app with configurable model
Here's a complete example showing how to use parameters in a FastAPI app:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "fastapi",
#     "uvicorn",
#     "joblib",
#     "scikit-learn",
#     "flyte>=2.0.0b52",
# ]
# ///
from contextlib import asynccontextmanager
from pathlib import Path

import flyte
import flyte.app
import flyte.io
from flyte.app.extras import FastAPIAppEnvironment
from fastapi import FastAPI

# {{docs-fragment model-serving-api}}
image = flyte.Image.from_uv_script(__file__, name="app-parameters-fastapi-example")
task_env = flyte.TaskEnvironment(
    name="model_serving_task",
    image=image,
    resources=flyte.Resources(cpu=2, memory="1Gi"),
    cache="auto",
)

@task_env.task
async def train_model_task() -> flyte.io.File:
    """Train a model and return it."""
    import joblib
    import sklearn.datasets
    import sklearn.ensemble

    X, y = sklearn.datasets.make_classification(n_samples=1000, n_features=5, n_classes=2, random_state=42)
    model = sklearn.ensemble.RandomForestClassifier()
    model.fit(X, y)
    model_dir = Path("/tmp/model")
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "model.joblib"
    joblib.dump(model, model_path)
    return await flyte.io.File.from_local(model_path)

state = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    import joblib

    model = joblib.load("/root/models/model.joblib")
    state["model"] = model
    yield

app = FastAPI(lifespan=lifespan)
app_env = FastAPIAppEnvironment(
    name="model-serving-api",
    app=app,
    parameters=[
        flyte.app.Parameter(
            name="model_file",
            # this is a placeholder
            value=flyte.io.File.from_existing_remote("s3://bucket/models/default.pkl"),
            mount="/root/models/",
            download=True,
        ),
    ],
    image=image,
    resources=flyte.Resources(cpu=2, memory="2Gi"),
    requires_auth=False,
)

@app.post("/predict")
async def predict(data: list[float]) -> dict[str, list[float]]:
    model = state["model"]
    return {"prediction": model.predict([data]).tolist()}

if __name__ == "__main__":
    import logging

    flyte.init_from_config(log_level=logging.DEBUG)
    run = flyte.run(train_model_task)
    print(f"Run: {run.url}")
    run.wait()
    model_file = run.outputs()[0]
    print(f"Model file: {model_file.path}")
    app = flyte.with_servecontext(
        parameter_values={
            "model-serving-api": {
                "model_file": flyte.io.File.from_existing_remote(model_file.path)
            }
        }
    ).serve(app_env)
    print(f"API URL: {app.url}")
# {{/docs-fragment model-serving-api}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-parameters-fastapi-example.py*
## Example: Using RunOutput for model serving
```python
import joblib
from sklearn.ensemble import RandomForestClassifier
from fastapi import FastAPI

import flyte
import flyte.app
import flyte.io
from flyte.app.extras import FastAPIAppEnvironment

# Training task
training_env = flyte.TaskEnvironment(name="training-env")

@training_env.task
async def train_model_task() -> flyte.io.File:
    """Train a model and return it."""
    model = RandomForestClassifier()
    # ... training logic ...
    path = "./trained-model.pkl"
    joblib.dump(model, path)
    return await flyte.io.File.from_local(path)

# Serving app that uses the trained model
app = FastAPI()
serving_env = FastAPIAppEnvironment(
    name="model-serving-app",
    app=app,
    parameters=[
        flyte.app.Parameter(
            name="model",
            value=flyte.app.RunOutput(
                type="file",
                task_name="training-env.train_model_task",
            ),
            mount="/app/model",
            env_var="MODEL_PATH",
        ),
    ],
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-parameters-examples.py*
## Accessing parameters in your app
How you access parameters depends on how they're configured:
1. **Environment variables**: If `env_var` is specified, the parameter is available as an environment variable
2. **Mounted paths**: File and directory parameters are mounted at the specified path
3. **Flyte SDK**: Use the Flyte SDK to access parameter values programmatically
```python
import os
import pickle

import flyte
import flyte.app
import flyte.io

# Parameter with both a mount path and an env_var specified
env = flyte.app.AppEnvironment(
    name="my-app",
    parameters=[
        flyte.app.Parameter(
            name="model_file",
            value=flyte.io.File("s3://bucket/model.pkl"),
            mount="/app/models/model.pkl",
            env_var="MODEL_PATH",
        ),
    ],
    # ...
)

# Access in the app via the environment variable
model_path = os.getenv("MODEL_PATH")

# Access in the app via the mounted path
with open("/app/models/model.pkl", "rb") as f:
    model = pickle.load(f)

# Access in the app via the Flyte SDK (for string parameters)
parameter_value = flyte.app.get_parameter("model_file")  # Returns string value
```
## Best practices
1. **Use delayed parameters**: Leverage `RunOutput` and `AppEndpoint` to connect apps to task outputs and to other apps without hardcoding values.
2. **Override for testing**: Use parameter overrides with `flyte.with_servecontext(...).serve(...)` to test different configurations without changing code.
3. **Mount paths clearly**: Use descriptive mount paths for file/directory parameters so your app code is easy to understand.
4. **Use environment variables**: For simple configuration values, set `env_var` to inject them as environment variables instead of hard-coding them.
5. **Production deployments**: For production, define parameters in the `AppEnvironment` rather than overriding them at serve time.
## Limitations
- Large files/directories can slow down app startup.
- Parameter overrides are only available when using `flyte.with_servecontext(...).serve(...)`.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/auto-scaling-apps ===
## Autoscaling apps
Flyte apps support autoscaling, allowing them to scale up and down based on traffic. This helps optimize costs by scaling down when there's no traffic and scaling up when needed.
### Scaling configuration
The `scaling` parameter uses a `[[Scaling]]` object to configure autoscaling behavior:
```python
scaling=flyte.app.Scaling(
    replicas=(min_replicas, max_replicas),
    scaledown_after=idle_ttl_seconds,
)
```
#### Parameters
- **`replicas`**: A tuple `(min_replicas, max_replicas)` specifying the minimum and maximum number of replicas.
- **`scaledown_after`**: Time in seconds to wait before scaling down when idle (idle TTL).
### Basic scaling example
Here's a simple example with scaling from 0 to 1 replica:
```python
import flyte
import flyte.app

# Basic example: scale from 0 to 1 replica
app_env = flyte.app.AppEnvironment(
    name="autoscaling-app",
    scaling=flyte.app.Scaling(
        replicas=(0, 1),      # Scale from 0 to 1 replica
        scaledown_after=300,  # Scale down after 5 minutes of inactivity
    ),
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py*
This configuration:
- Starts with 0 replicas (no running instances)
- Scales up to 1 replica when there's traffic
- Scales back down to 0 after 5 minutes (300 seconds) of no traffic
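The behavior above can be sketched as a toy decision rule (illustrative only; the real controller is traffic-driven and more sophisticated):

```python
def desired_replicas(
    has_traffic: bool,
    idle_seconds: float,
    *,
    min_replicas: int = 0,
    max_replicas: int = 1,
    scaledown_after: float = 300,
) -> int:
    """Toy model of the rule: serve while there is traffic, drop to the floor after the idle TTL."""
    if has_traffic:
        return max(min_replicas, 1)  # at least one replica handles the traffic (capped by max_replicas under load)
    if idle_seconds >= scaledown_after:
        return min_replicas  # idle TTL elapsed: scale down (to zero here)
    return max(min_replicas, 1)  # still inside the idle window: stay warm
```

With `min_replicas=0` this scales to zero after the TTL; with `min_replicas > 0` the floor keeps the app warm instead.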
### Scaling patterns
#### Always-on app
For apps that need to always be running:
```python
import flyte
import flyte.app

# Always-on app
app_env2 = flyte.app.AppEnvironment(
    name="always-on-api",
    scaling=flyte.app.Scaling(
        replicas=(1, 1),  # Always keep 1 replica running
        # scaledown_after is ignored when min_replicas > 0
    ),
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py*
#### Scale-to-zero app
For apps that can scale to zero when idle:
```python
import flyte
import flyte.app

# Scale-to-zero app
app_env3 = flyte.app.AppEnvironment(
    name="scale-to-zero-app",
    scaling=flyte.app.Scaling(
        replicas=(0, 1),      # Can scale down to 0
        scaledown_after=600,  # Scale down after 10 minutes of inactivity
    ),
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py*
#### High-availability app
For apps that need multiple replicas for availability:
```python
import flyte
import flyte.app

# High-availability app
app_env4 = flyte.app.AppEnvironment(
    name="ha-api",
    scaling=flyte.app.Scaling(
        replicas=(2, 5),      # Keep at least 2, scale up to 5
        scaledown_after=300,  # Scale down after 5 minutes
    ),
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py*
#### Burstable app
For apps with variable load:
```python
import flyte
import flyte.app

# Burstable app
app_env5 = flyte.app.AppEnvironment(
    name="bursty-app",
    scaling=flyte.app.Scaling(
        replicas=(1, 10),     # Start with 1, scale up to 10 under load
        scaledown_after=180,  # Scale down quickly after 3 minutes
    ),
    # ...
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py*
### Idle TTL (Time To Live)
The `scaledown_after` parameter (idle TTL) determines how long an app instance can be idle before it's scaled down.
#### Considerations
- **Too short**: May cause frequent scale up/down cycles, leading to cold starts.
- **Too long**: Keeps resources running unnecessarily, increasing costs.
- **Optimal**: Balance between cost and user experience.
#### Common idle TTL values
- **Development/Testing**: 60-180 seconds (1-3 minutes) - quick scale down for cost savings.
- **Production APIs**: 300-600 seconds (5-10 minutes) - balance cost and responsiveness.
- **Batch processing**: 900-1800 seconds (15-30 minutes) - longer to handle bursts.
- **Always-on**: Set `min_replicas > 0` - never scale down.
### Autoscaling best practices
1. **Start conservative**: Begin with longer idle TTL values and adjust based on usage.
2. **Monitor cold starts**: Track how long it takes for your app to become ready after scaling up.
3. **Consider costs**: Balance idle TTL between cost savings and user experience.
4. **Use appropriate min replicas**: Set `min_replicas > 0` for critical apps that need to be always available.
5. **Test scaling behavior**: Verify your app handles scale up/down correctly (for example, state management and connections).
### Autoscaling limitations
- Scaling is based on traffic/request patterns, not CPU/memory utilization.
- Cold starts may occur when scaling from zero.
- Stateful apps need careful design to handle scaling (use external state stores).
- Maximum replicas are limited by your cluster capacity.
### Autoscaling troubleshooting
**App scales down too quickly:**
- Increase `scaledown_after` value.
- Set `min_replicas > 0` if the app needs to stay warm.
**App doesn't scale up fast enough:**
- Ensure your cluster has capacity.
- Check if there are resource constraints.
**Cold starts are too slow:**
- Pre-warm with `min_replicas = 1`.
- Optimize app startup time.
- Consider using faster storage for model loading.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/apps-depending-on-environments ===
# Apps depending on other environments
The `depends_on` parameter allows you to specify that one app depends on another app (or task environment). When you deploy an app with `depends_on`, Flyte ensures that all dependencies are deployed first.
## Basic usage
Use `depends_on` to specify a list of environments that this app depends on:
```python
app1_env = flyte.app.AppEnvironment(name="backend-api", ...)

app2_env = flyte.app.AppEnvironment(
    name="frontend-app",
    depends_on=[app1_env],  # Ensure backend-api is deployed first
    # ...
)
```
When you deploy `app2_env`, Flyte will:
1. First deploy `app1_env` (if not already deployed)
2. Then deploy `app2_env`
3. Make sure `app1_env` is available before `app2_env` starts
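Conceptually, this is a dependency-ordering problem: every environment must be deployed after all of its dependencies. A toy depth-first ordering (illustrative only, not the SDK's implementation):

```python
def deploy_order(depends_on: dict[str, list[str]]) -> list[str]:
    """Return environment names ordered so every dependency precedes its dependents."""
    order: list[str] = []
    seen: set[str] = set()

    def visit(name: str) -> None:
        if name in seen:
            return
        seen.add(name)
        for dep in depends_on.get(name, []):
            visit(dep)  # deploy dependencies first
        order.append(name)

    for name in depends_on:
        visit(name)
    return order
```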
## Example: App calling another app
Here's a complete example where one FastAPI app calls another:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "fastapi",
#     "httpx",
# ]
# ///
"""Example of one app calling another app."""
import pathlib

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi", "uvicorn", "httpx"
)

# {{docs-fragment backend-app}}
app1 = FastAPI(
    title="App 1",
    description="A FastAPI app that runs some computations",
)
env1 = FastAPIAppEnvironment(
    name="app1-is-called-by-app2",
    app=app1,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)
# {{/docs-fragment backend-app}}

# {{docs-fragment frontend-app}}
app2 = FastAPI(
    title="App 2",
    description="A FastAPI app that proxies requests to another FastAPI app",
)
env2 = FastAPIAppEnvironment(
    name="app2-calls-app1",
    app=app2,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    depends_on=[env1],  # Depends on app1-is-called-by-app2
)
# {{/docs-fragment frontend-app}}

# {{docs-fragment backend-endpoint}}
@app1.get("/greeting/{name}")
async def greeting(name: str) -> str:
    return f"Hello, {name}!"
# {{/docs-fragment backend-endpoint}}

# {{docs-fragment frontend-endpoints}}
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
    return env1.endpoint  # Access the backend endpoint

@app2.get("/greeting/{name}")
async def greeting_proxy(name: str):
    """Proxy that calls the backend app."""
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{env1.endpoint}/greeting/{name}")
        response.raise_for_status()
        return response.json()
# {{/docs-fragment frontend-endpoints}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    deployments = flyte.deploy(env2)
    print(f"Deployed FastAPI app: {deployments[0].env_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/app_calling_app.py*
When you deploy `env2`, Flyte will:
1. Deploy `env1` (the backend app) first
2. Wait for `env1` to be ready
3. Deploy `env2` (the frontend app)
4. `env2` can then access `env1.endpoint` to make requests
## Dependency chain
You can create chains of dependencies:
```python
app1_env = flyte.app.AppEnvironment(name="service-1", ...)
app2_env = flyte.app.AppEnvironment(name="service-2", depends_on=[app1_env], ...)
app3_env = flyte.app.AppEnvironment(name="service-3", depends_on=[app2_env], ...)
# Deploying app3_env will deploy in order: app1_env -> app2_env -> app3_env
```
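Conceptually, resolving a `depends_on` chain is a topological sort of the dependency graph: dependencies come before dependents. A minimal pure-Python sketch of that ordering (not Flyte's actual implementation) using the standard library's `graphlib`:

```python
from graphlib import TopologicalSorter

# Each app maps to the set of apps it depends on,
# mirroring the service-1 <- service-2 <- service-3 chain above.
graph = {
    "service-3": {"service-2"},
    "service-2": {"service-1"},
    "service-1": set(),
}

# static_order() yields dependencies before their dependents.
deploy_order = list(TopologicalSorter(graph).static_order())
print(deploy_order)  # ['service-1', 'service-2', 'service-3']
```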
## Multiple dependencies
An app can depend on multiple environments:
```python
backend_env = flyte.app.AppEnvironment(name="backend", ...)
database_env = flyte.app.AppEnvironment(name="database", ...)
api_env = flyte.app.AppEnvironment(
name="api",
depends_on=[backend_env, database_env], # Depends on both
# ...
)
```
When deploying `api_env`, both `backend_env` and `database_env` will be deployed first (they may be deployed in parallel if they don't depend on each other).
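The "parallel where possible" behavior can be pictured the same way: independent dependencies become ready in the same batch. A sketch under that interpretation (again, not Flyte's implementation):

```python
from graphlib import TopologicalSorter

# api depends on backend and database; backend and database are independent.
graph = {"api": {"backend", "database"}, "backend": set(), "database": set()}

ts = TopologicalSorter(graph)
ts.prepare()
while ts.is_active():
    batch = sorted(ts.get_ready())  # everything whose dependencies are satisfied
    print(batch)                    # ['backend', 'database'], then ['api']
    ts.done(*batch)
```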
## Using AppEndpoint for dependency URLs
When one app depends on another, you can use `AppEndpoint` to get the URL:
```python
backend_env = flyte.app.AppEnvironment(name="backend-api", ...)
frontend_env = flyte.app.AppEnvironment(
name="frontend-app",
depends_on=[backend_env],
parameters=[
flyte.app.Parameter(
name="backend_url",
value=flyte.app.AppEndpoint(app_name="backend-api"),
),
],
# ...
)
```
The `backend_url` parameter will be automatically set to the backend app's endpoint URL.
You can get this value in your app code using `flyte.app.get_input("backend_url")`.
## Deployment behavior
When deploying with `flyte.deploy()`:
```python
# Deploy the app (dependencies are automatically deployed)
deployments = flyte.deploy(env2)
# All dependencies are included in the deployment plan
for deployment in deployments:
print(f"Deployed: {deployment.env.name}")
```
Flyte will:
1. Build a deployment plan that includes all dependencies
2. Deploy dependencies in the correct order
3. Ensure dependencies are ready before deploying dependent apps
## Task environment dependencies
You can also depend on task environments:
```python
task_env = flyte.TaskEnvironment(name="training-env", ...)
serving_env = flyte.app.AppEnvironment(
name="serving-app",
depends_on=[task_env], # Can depend on task environments too
# ...
)
```
This ensures the task environment is available when the app is deployed (useful if the app needs to call tasks in that environment).
## Best practices
1. **Explicit dependencies**: Always use `depends_on` to make app dependencies explicit
2. **Circular dependencies**: Avoid circular dependencies (app A depends on B, B depends on A)
3. **Dependency order**: Design your dependency graph to be a DAG (Directed Acyclic Graph)
4. **Endpoint access**: Use `AppEndpoint` to pass dependency URLs as inputs
5. **Document dependencies**: Make sure your app documentation explains its dependencies
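Circular dependencies fail for a structural reason: no valid deploy order exists. The standard library's `graphlib` makes this concrete (an illustration, not Flyte's error handling):

```python
from graphlib import CycleError, TopologicalSorter

# app-a depends on app-b and app-b depends on app-a: a cycle.
graph = {"app-a": {"app-b"}, "app-b": {"app-a"}}

try:
    list(TopologicalSorter(graph).static_order())
except CycleError:
    print("no valid deploy order: circular dependency")
```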
## Example: A/B testing with dependencies
Here's an example of an A/B testing setup where a root app depends on two variant apps:
```python
app_a = FastAPI(title="Variant A")
app_b = FastAPI(title="Variant B")
root_app = FastAPI(title="Root App")
env_a = FastAPIAppEnvironment(name="app-a-variant", app=app_a, ...)
env_b = FastAPIAppEnvironment(name="app-b-variant", app=app_b, ...)
env_root = FastAPIAppEnvironment(
name="root-ab-testing-app",
app=root_app,
depends_on=[env_a, env_b], # Depends on both variants
# ...
)
```
The root app can route traffic to either variant A or B based on A/B testing logic, and both variants will be deployed before the root app starts.
## Limitations
- Circular dependencies are not supported
- Dependencies must be in the same project/domain
- Dependency deployment order is deterministic but dependencies at the same level may deploy in parallel
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps ===
# Build apps
> [!NOTE]
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
This section covers how to build different types of apps with Flyte, including Streamlit dashboards, FastAPI REST APIs, vLLM and SGLang model servers, webhooks, and WebSocket applications.
> [!TIP]
> Go to **Core concepts > Apps** for an overview of apps and a quick example.
## App types
Flyte supports various types of apps:
- **UI dashboard apps**: Interactive web dashboards and data visualization tools like Streamlit and Gradio
- **Web API apps**: REST APIs, webhooks, and backend services like FastAPI and Flask
- **Model serving apps**: High-performance LLM serving with vLLM and SGLang
## Next steps
- **Build apps > Single-script apps**: The simplest way to build and deploy apps in a single Python script
- **Build apps > Multi-script apps**: Build FastAPI and Streamlit apps with multiple files
- **Build apps > App usage patterns**: Call apps from tasks, tasks from apps, and apps from apps
- **Build apps > Secret-based authentication**: Authenticate FastAPI apps using Flyte secrets
- **Build apps > Streamlit app**: Build interactive Streamlit dashboards
- **Build apps > FastAPI app**: Create REST APIs and backend services
- **Build apps > vLLM app**: Serve large language models with vLLM
- **Build apps > SGLang app**: Serve LLMs with SGLang for structured generation
## Subpages
- **Build apps > Single-script apps**
- **Build apps > Multi-script apps**
- **Build apps > App usage patterns**
- **Build apps > Secret-based authentication**
- **Build apps > Streamlit app**
- **Build apps > FastAPI app**
- **Build apps > vLLM app**
- **Build apps > SGLang app**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/single-script-apps ===
# Single-script apps
The simplest way to build and deploy an app with Flyte is to write everything in a single Python script. This approach is perfect for:
- **Quick prototypes**: Rapidly test ideas and concepts
- **Simple services**: Basic HTTP servers, APIs, or dashboards
- **Learning**: Understanding how Flyte apps work without complexity
- **Minimal examples**: Demonstrating core functionality
All the code for your app (the application logic, the app environment configuration, and the deployment code) lives in one file. This makes it easy to understand, share, and deploy.
## Plain Python HTTP server
The simplest possible app is a plain Python HTTP server using Python's built-in `http.server` module. This requires no external dependencies beyond the Flyte SDK.
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# ///
"""A plain Python HTTP server example - the simplest possible app."""
import flyte
import flyte.app
from pathlib import Path
# {{docs-fragment server-code}}
# Create a simple HTTP server handler
from http.server import HTTPServer, BaseHTTPRequestHandler
class SimpleHandler(BaseHTTPRequestHandler):
"""A simple HTTP server handler."""
def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
            self.wfile.write(b"<h1>Hello from Plain Python Server!</h1>")
elif self.path == "/health":
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(b'{"status": "healthy"}')
else:
self.send_response(404)
self.end_headers()
# {{/docs-fragment server-code}}
# {{docs-fragment app-env}}
file_name = Path(__file__).name
app_env = flyte.app.AppEnvironment(
name="plain-python-server",
image=flyte.Image.from_debian_base(python_version=(3, 12)),
args=["python", file_name, "--server"],
port=8080,
resources=flyte.Resources(cpu="1", memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
import sys
if "--server" in sys.argv:
server = HTTPServer(("0.0.0.0", 8080), SimpleHandler)
print("Server running on port 8080")
server.serve_forever()
else:
flyte.init_from_config(root_dir=Path(__file__).parent)
app = flyte.serve(app_env)
print(f"App URL: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/plain_python_server.py*
**Key points**
- **No external dependencies**: Uses only Python's standard library
- **Simple handler**: Define request handlers as Python classes
- **Basic command**: Run the server with a simple Python command
- **Minimal resources**: Requires only basic CPU and memory
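Before deploying, a handler like this can be smoke-tested locally with only the standard library. A sketch (using a hypothetical `HealthHandler` that mirrors the `/health` route above) that starts the server on an ephemeral port in a background thread and makes a request:

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

class HealthHandler(BaseHTTPRequestHandler):
    """Minimal handler exposing a /health endpoint, as in the example above."""

    def do_GET(self):
        if self.path == "/health":
            self.send_response(200)
            self.send_header("Content-type", "application/json")
            self.end_headers()
            self.wfile.write(b'{"status": "healthy"}')
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, *args):
        pass  # keep request logging quiet during the smoke test

server = HTTPServer(("127.0.0.1", 0), HealthHandler)  # port 0: OS picks a free port
threading.Thread(target=server.serve_forever, daemon=True).start()

with urllib.request.urlopen(f"http://127.0.0.1:{server.server_port}/health") as resp:
    print(json.load(resp))  # {'status': 'healthy'}

server.shutdown()
```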
## Streamlit app
Streamlit makes it easy to build interactive web dashboards. Here's a complete single-script Streamlit app:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "streamlit",
# ]
# ///
"""A single-script Streamlit app example."""
import pathlib
import streamlit as st
import flyte
import flyte.app
# {{docs-fragment streamlit-app}}
def main():
    st.set_page_config(page_title="Simple Streamlit App", page_icon="🎈")
st.title("Hello from Streamlit!")
st.write("This is a simple single-script Streamlit app.")
name = st.text_input("What's your name?", "World")
st.write(f"Hello, {name}!")
if st.button("Click me!"):
st.balloons()
st.success("Button clicked!")
# {{/docs-fragment streamlit-app}}
# {{docs-fragment app-env}}
file_name = pathlib.Path(__file__).name
app_env = flyte.app.AppEnvironment(
name="streamlit-single-script",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1"
),
args=["streamlit", "run", file_name, "--server.port", "8080", "--", "--server"],
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
import sys
if "--server" in sys.argv:
main()
else:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.serve(app_env)
print(f"App URL: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit_single_script.py*
**Key points**
- **Interactive UI**: Streamlit provides widgets and visualizations out of the box
- **Single file**: All UI logic and deployment code in one script
- **Simple deployment**: Just specify the Streamlit command and port
- **Rich ecosystem**: Access to Streamlit's extensive component library
## FastAPI app
FastAPI is a modern, fast web framework for building APIs. Here's a minimal single-script FastAPI app:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""A single-script FastAPI app example - the simplest FastAPI app."""
from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment fastapi-app}}
app = FastAPI(
title="Simple FastAPI App",
description="A minimal single-script FastAPI application",
version="1.0.0",
)
@app.get("/")
async def root():
return {"message": "Hello, World!"}
@app.get("/health")
async def health():
return {"status": "healthy"}
# {{/docs-fragment fastapi-app}}
# {{docs-fragment app-env}}
app_env = FastAPIAppEnvironment(
name="fastapi-single-script",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.serve(app_env)
print(f"Deployed: {app_deployment.url}")
print(f"API docs: {app_deployment.url}/docs")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi_single_script.py*
**Key points**
- **FastAPIAppEnvironment**: Automatically configures uvicorn and FastAPI
- **Type hints**: FastAPI uses Python type hints for automatic validation
- **Auto docs**: Interactive API documentation at `/docs` endpoint
- **Async support**: Built-in support for async/await patterns
## Running single-script apps
To run any of these examples:
1. **Save the script** to a file (e.g., `my_app.py`)
2. **Ensure you have a config file** (`./.flyte/config.yaml` or `./config.yaml`)
3. **Run the script**:
```bash
python my_app.py
```
Or using `uv`:
```bash
uv run my_app.py
```
The script will:
- Initialize Flyte from your config
- Deploy the app to your Union/Flyte instance
- Print the app URL
## When to use single-script apps
**Use single-script apps when:**
- Building prototypes or proof-of-concepts
- Creating simple services with minimal logic
- Learning how Flyte apps work
- Sharing complete, runnable examples
- Building demos or tutorials
**Consider multi-script apps when:**
- Your app grows beyond a few hundred lines
- You need to organize code into modules
- You want to reuse components across apps
- You're building production applications
See [**Multi-script apps**](./multi-script-apps) for examples of organizing apps across multiple files.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/multi-script-apps ===
# Multi-script apps
Real-world applications often span multiple files. This page shows how to build FastAPI and Streamlit apps with multiple Python files.
## FastAPI multi-script app
### Project structure
```
project/
├── app.py       # Main FastAPI app file
└── module.py    # Helper module
```
### Example: Multi-file FastAPI app
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""Multi-file FastAPI app example."""
from fastapi import FastAPI
from module import function # Import from another file
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment app-definition}}
app = FastAPI(title="Multi-file FastAPI Demo")
app_env = FastAPIAppEnvironment(
name="fastapi-multi-file",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
# FastAPIAppEnvironment automatically includes necessary files
# But you can also specify explicitly:
# include=["app.py", "module.py"],
)
# {{/docs-fragment app-definition}}
# {{docs-fragment endpoint}}
@app.get("/")
async def root():
return function() # Uses function from module.py
# {{/docs-fragment endpoint}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(app_env)
print(f"Deployed: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/app.py*
```
# {{docs-fragment helper-function}}
def function():
"""Helper function used by the FastAPI app."""
return {"message": "Hello from module.py!"}
# {{/docs-fragment helper-function}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/module.py*
### Automatic file discovery
`FastAPIAppEnvironment` automatically discovers and includes the necessary files by analyzing your imports. However, if you have files that aren't automatically detected (like configuration files or data files), you can explicitly include them:
```python
app_env = FastAPIAppEnvironment(
name="fastapi-with-config",
app=app,
include=["app.py", "module.py", "config.yaml"], # Explicit includes
# ...
)
```
## Streamlit multi-script app
### Project structure
```
project/
├── main.py          # Main Streamlit app
├── utils.py         # Utility functions
└── components.py    # Reusable components
```
### Example: Multi-file Streamlit app
```
import os
import streamlit as st
from utils import generate_data
# {{docs-fragment streamlit-app}}
all_columns = ["Apples", "Orange", "Pineapple"]
with st.container(border=True):
columns = st.multiselect("Columns", all_columns, default=all_columns)
all_data = st.cache_data(generate_data)(columns=all_columns, seed=101)
data = all_data[columns]
tab1, tab2 = st.tabs(["Chart", "Dataframe"])
tab1.line_chart(data, height=250)
tab2.dataframe(data, height=250, use_container_width=True)
st.write(f"Environment: {os.environ}")
# {{/docs-fragment streamlit-app}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/main.py*
```
import numpy as np
import pandas as pd
# {{docs-fragment utils-function}}
def generate_data(columns: list[str], seed: int = 42):
rng = np.random.default_rng(seed)
data = pd.DataFrame(rng.random(size=(20, len(columns))), columns=columns)
return data
# {{/docs-fragment utils-function}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/utils.py*
### Deploying multi-file Streamlit app
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# ///
"""A custom Streamlit app with multiple files."""
import pathlib
import flyte
import flyte.app
# {{docs-fragment app-env}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1",
"pandas==2.2.3",
"numpy==2.2.3",
)
app_env = flyte.app.AppEnvironment(
name="streamlit-multi-file-app",
image=image,
args="streamlit run main.py --server.port 8080",
port=8080,
include=["main.py", "utils.py"], # Include your app files
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.deploy(app_env)
print(f"Deployed app: {app[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/multi_file_streamlit.py*
## Complex multi-file example
Here's a more complex example with multiple modules:
### Project structure
```
project/
├── app.py
├── models/
│   ├── __init__.py
│   └── user.py
├── services/
│   ├── __init__.py
│   └── auth.py
└── utils/
    ├── __init__.py
    └── helpers.py
```
### Example code
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""Complex multi-file FastAPI app example."""
from pathlib import Path
from fastapi import FastAPI
from models.user import User
from services.auth import authenticate
from utils.helpers import format_response
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment complex-app}}
app = FastAPI(title="Complex Multi-file App")
@app.get("/users/{user_id}")
async def get_user(user_id: int):
user = User(id=user_id, name="John Doe")
return format_response(user)
# {{/docs-fragment complex-app}}
# {{docs-fragment complex-env}}
app_env = FastAPIAppEnvironment(
name="complex-app",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
"pydantic",
),
# Include all necessary files
include=[
"app.py",
"models/",
"services/",
"utils/",
],
resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment complex-env}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=Path(__file__).parent)
app_deployment = flyte.deploy(app_env)
print(f"Deployed: {app_deployment[0].summary_repr()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/app.py*
```
# {{docs-fragment user-model}}
from pydantic import BaseModel
class User(BaseModel):
id: int
name: str
# {{/docs-fragment user-model}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/models/user.py*
```
# {{docs-fragment auth-service}}
def authenticate(token: str) -> bool:
"""Authenticate a user by token."""
# ... authentication logic ...
return True
# {{/docs-fragment auth-service}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/services/auth.py*
```
# {{docs-fragment helpers}}
def format_response(data):
"""Format a response with standard structure."""
return {"data": data, "status": "success"}
# {{/docs-fragment helpers}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/utils/helpers.py*
## Best practices
1. **Use explicit includes**: For Streamlit apps, explicitly list all files in `include`
2. **Automatic discovery**: For FastAPI apps, `FastAPIAppEnvironment` handles most cases automatically
3. **Organize modules**: Use proper Python package structure with `__init__.py` files
4. **Test locally**: Test your multi-file app locally before deploying
5. **Include all dependencies**: Include all files that your app imports
## Troubleshooting
**Import errors:**
- Verify all files are included in the `include` parameter
- Check that file paths are correct (relative to app definition file)
- Ensure `__init__.py` files are included for packages
**Module not found:**
- Add missing files to the `include` list
- Check that import paths match the file structure
- Verify that the image includes all necessary packages
**File not found at runtime:**
- Ensure all referenced files are included
- Check mount paths for file/directory inputs
- Verify file paths are relative to the app root directory
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/app-usage-patterns ===
# App usage patterns
Apps and tasks can interact in various ways: calling each other via HTTP, webhooks, WebSockets, or direct browser usage. This page describes the different patterns and when to use them.
## Patterns overview
1. **Build apps > App usage patterns > Call app from task**: A task makes HTTP requests to an app
2. **Build apps > App usage patterns > Call task from app (webhooks / APIs)**: An app triggers task execution via the Flyte SDK
3. **Build apps > App usage patterns > Call app from app**: One app makes HTTP requests to another app
4. **Build apps > App usage patterns > WebSocket-based patterns**: Real-time, bidirectional communication
5. **Browser-based access**: Users access apps directly through the browser
## Call app from task
Tasks can call apps by making HTTP requests to the app's endpoint. This is useful when:
- You need to use a long-running service during task execution
- You want to call a model serving endpoint from a batch processing task
- You need to interact with an API from a workflow
### Example: Task calling an app
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# "httpx",
# ]
# ///
"""Example of a task calling an app."""
import pathlib
import httpx
from fastapi import FastAPI
import flyte
from flyte.app.extras import FastAPIAppEnvironment
app = FastAPI(title="Add One", description="Adds one to the input", version="1.0.0")
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("fastapi", "uvicorn", "httpx")
# {{docs-fragment app-definition}}
app_env = FastAPIAppEnvironment(
name="add-one-app",
app=app,
description="Adds one to the input",
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment app-definition}}
# {{docs-fragment task-env}}
task_env = flyte.TaskEnvironment(
name="add_one_task_env",
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
depends_on=[app_env], # Ensure app is deployed before task runs
)
# {{/docs-fragment task-env}}
# {{docs-fragment app-endpoint}}
@app.get("/")
async def add_one(x: int) -> dict[str, int]:
"""Main endpoint for the add-one app."""
return {"result": x + 1}
# {{/docs-fragment app-endpoint}}
# {{docs-fragment task}}
@task_env.task
async def add_one_task(x: int) -> int:
print(f"Calling app at {app_env.endpoint}")
async with httpx.AsyncClient() as client:
response = await client.get(app_env.endpoint, params={"x": x})
response.raise_for_status()
return response.json()["result"]
# {{/docs-fragment task}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
deployments = flyte.deploy(task_env)
print(f"Deployed task environment: {deployments}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/task_calling_app.py*
Key points:
- The task environment uses `depends_on=[app_env]` to ensure the app is deployed first
- Access the app endpoint via `app_env.endpoint`
- Use standard HTTP client libraries (like `httpx`) to make requests
## Call task from app (webhooks / APIs)
Apps can trigger task execution using the Flyte SDK. This is useful for:
- Webhooks that trigger workflows
- APIs that need to run batch jobs
- Services that need to execute tasks asynchronously
Webhooks are HTTP endpoints that trigger actions in response to external events. Flyte apps can serve as webhook endpoints that trigger task runs, workflows, or other operations.
### Example: Basic webhook app
Here's a simple webhook that triggers Flyte tasks:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""A webhook that triggers Flyte tasks."""
import pathlib
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
from contextlib import asynccontextmanager
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment auth}}
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY", "test-api-key")
security = HTTPBearer()
async def verify_token(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
"""Verify the API key from the bearer token."""
if credentials.credentials != WEBHOOK_API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
return credentials
# {{/docs-fragment auth}}
# {{docs-fragment lifespan}}
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize Flyte before accepting requests."""
await flyte.init_in_cluster.aio()
yield
# Cleanup if needed
# {{/docs-fragment lifespan}}
# {{docs-fragment app}}
app = FastAPI(
title="Flyte Webhook Runner",
description="A webhook service that triggers Flyte task runs",
version="1.0.0",
lifespan=lifespan,
)
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy"}
# {{/docs-fragment app}}
# {{docs-fragment webhook-endpoint}}
@app.post("/run-task/{project}/{domain}/{name}/{version}")
async def run_task(
project: str,
domain: str,
name: str,
version: str,
inputs: dict,
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
"""
Trigger a Flyte task run via webhook.
Returns information about the launched run.
"""
# Fetch the task
task = remote.Task.get(
project=project,
domain=domain,
name=name,
version=version,
)
# Run the task
run = await flyte.run.aio(task, **inputs)
return {
"url": run.url,
"id": run.id,
"status": "started",
}
# {{/docs-fragment webhook-endpoint}}
# {{docs-fragment env}}
env = FastAPIAppEnvironment(
name="webhook-runner",
app=app,
description="A webhook service that triggers Flyte task runs",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False, # We handle auth in the app
env_vars={"WEBHOOK_API_KEY": os.getenv("WEBHOOK_API_KEY", "test-api-key")},
)
# {{/docs-fragment env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed webhook: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/webhook/basic_webhook.py*
Once deployed, you can trigger tasks via HTTP POST:
```bash
curl -X POST "https://your-webhook-url/run-task/flytesnacks/development/my_task/v1" \
-H "Authorization: Bearer test-api-key" \
-H "Content-Type: application/json" \
-d '{"input_key": "input_value"}'
```
Response:
```json
{
"url": "https://console.union.ai/...",
"id": "abc123",
"status": "started"
}
```
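The same request can be built from Python using only the standard library. The URL and API key below are the placeholders from the `curl` example, not real values:

```python
import json
import urllib.request

# Placeholder URL and key from the curl example above; substitute your own.
req = urllib.request.Request(
    "https://your-webhook-url/run-task/flytesnacks/development/my_task/v1",
    data=json.dumps({"input_key": "input_value"}).encode(),
    headers={
        "Authorization": "Bearer test-api-key",
        "Content-Type": "application/json",
    },
    method="POST",
)
# urllib.request.urlopen(req) would send it; here we just show the request shape.
print(req.get_method(), req.full_url)
```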
### Advanced webhook patterns
**Webhook with validation**
Use Pydantic for input validation:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""A webhook with Pydantic validation."""
import pathlib
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
from contextlib import asynccontextmanager
from pydantic import BaseModel
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY", "test-api-key")
security = HTTPBearer()
async def verify_token(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
"""Verify the API key from the bearer token."""
if credentials.credentials != WEBHOOK_API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
return credentials
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize Flyte before accepting requests."""
await flyte.init_in_cluster.aio()
yield
app = FastAPI(
title="Flyte Webhook Runner with Validation",
description="A webhook service that triggers Flyte task runs with Pydantic validation",
version="1.0.0",
lifespan=lifespan,
)
# {{docs-fragment validation-model}}
class TaskInput(BaseModel):
data: dict
priority: int = 0
# {{/docs-fragment validation-model}}
# {{docs-fragment validated-webhook}}
@app.post("/run-task/{project}/{domain}/{name}/{version}")
async def run_task(
project: str,
domain: str,
name: str,
version: str,
inputs: TaskInput, # Validated input
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
task = remote.Task.get(
project=project,
domain=domain,
name=name,
version=version,
)
run = await flyte.run.aio(task, **inputs.model_dump())
return {
"run_id": run.id,
"url": run.url,
}
# {{/docs-fragment validated-webhook}}
env = FastAPIAppEnvironment(
name="webhook-with-validation",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
env_vars={"WEBHOOK_API_KEY": os.getenv("WEBHOOK_API_KEY", "test-api-key")},
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed webhook: {app_deployment[0].summary_repr()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/webhook_validation.py*
**Webhook with response waiting**
Wait for task completion:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""A webhook that waits for task completion."""
import pathlib
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
from contextlib import asynccontextmanager
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY", "test-api-key")
security = HTTPBearer()
async def verify_token(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
"""Verify the API key from the bearer token."""
if credentials.credentials != WEBHOOK_API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
return credentials
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize Flyte before accepting requests."""
await flyte.init_in_cluster.aio()
yield
app = FastAPI(
title="Flyte Webhook Runner (Wait for Completion)",
description="A webhook service that triggers Flyte task runs and waits for completion",
version="1.0.0",
lifespan=lifespan,
)
# {{docs-fragment wait-webhook}}
@app.post("/run-task-and-wait/{project}/{domain}/{name}/{version}")
async def run_task_and_wait(
project: str,
domain: str,
name: str,
version: str,
inputs: dict,
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
task = remote.Task.get(
project=project,
domain=domain,
name=name,
version=version,
)
run = await flyte.run.aio(task, **inputs)
    run.wait()  # Block until the run completes (long runs may exceed HTTP timeouts)
return {
"run_id": run.id,
"url": run.url,
"status": run.status,
"outputs": run.outputs(),
}
# {{/docs-fragment wait-webhook}}
env = FastAPIAppEnvironment(
name="webhook-wait-completion",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
env_vars={"WEBHOOK_API_KEY": os.getenv("WEBHOOK_API_KEY", "test-api-key")},
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed webhook: {app_deployment[0].summary_repr()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/webhook_wait.py*
**Webhook with secret management**
Use Flyte secrets for API keys:
```python
env = FastAPIAppEnvironment(
name="webhook-runner",
app=app,
secrets=flyte.Secret(key="webhook-api-key", as_env_var="WEBHOOK_API_KEY"),
# ...
)
```
Then access in your app:
```python
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY")
```
### Webhook security and best practices
- **Authentication**: Always secure webhooks with authentication (API keys, tokens, etc.).
- **Input validation**: Validate webhook inputs using Pydantic models.
- **Error handling**: Handle errors gracefully and return meaningful error messages.
- **Async operations**: Use async/await for I/O operations.
- **Health checks**: Include health check endpoints.
- **Logging**: Log webhook requests for debugging and auditing.
- **Rate limiting**: Consider implementing rate limiting for production.
Security considerations:
- Store API keys in Flyte secrets, not in code.
- Always use HTTPS in production.
- Validate all inputs to prevent injection attacks.
- Implement proper access control mechanisms.
- Log all webhook invocations for security auditing.
### Example: GitHub webhook
Here's an example webhook that triggers tasks based on GitHub events:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""A GitHub webhook that triggers Flyte tasks based on GitHub events."""
import pathlib
import hmac
import hashlib
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Header, HTTPException
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize Flyte before accepting requests."""
await flyte.init_in_cluster.aio()
yield
app = FastAPI(
title="GitHub Webhook Handler",
description="Triggers Flyte tasks based on GitHub events",
version="1.0.0",
lifespan=lifespan,
)
# {{docs-fragment github-webhook}}
@app.post("/github-webhook")
async def github_webhook(
request: Request,
x_hub_signature_256: str = Header(None),
):
"""Handle GitHub webhook events."""
    body = await request.body()
    # Verify the signature before trusting the payload
    secret = os.getenv("GITHUB_WEBHOOK_SECRET")
    if not secret or not x_hub_signature_256:
        raise HTTPException(status_code=403, detail="Missing signature")
    signature = hmac.new(
        secret.encode(),
        body,
        hashlib.sha256,
    ).hexdigest()
    expected_signature = f"sha256={signature}"
    if not hmac.compare_digest(x_hub_signature_256, expected_signature):
        raise HTTPException(status_code=403, detail="Invalid signature")
# Process webhook
event = await request.json()
event_type = request.headers.get("X-GitHub-Event")
if event_type == "push":
# Trigger deployment task
task = remote.Task.get(
project="my-project",
domain="development",
name="deploy-task",
version="v1",
)
run = await flyte.run.aio(task, commit=event["after"])
return {"run_id": run.id, "url": run.url}
return {"status": "ignored"}
# {{/docs-fragment github-webhook}}
# {{docs-fragment env}}
env = FastAPIAppEnvironment(
name="github-webhook",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
secrets=flyte.Secret(key="GITHUB_WEBHOOK_SECRET", as_env_var="GITHUB_WEBHOOK_SECRET"),
)
# {{/docs-fragment env}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed GitHub webhook: {app_deployment[0].summary_repr()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/github_webhook.py*
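To exercise the signature check locally, you can compute the same HMAC that GitHub attaches to each delivery. A small stdlib-only sketch (the secret and payload are made-up test values):

```python
import hashlib
import hmac
import json

def sign_payload(secret: str, body: bytes) -> str:
    """Compute the X-Hub-Signature-256 header value for a webhook body."""
    digest = hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
    return f"sha256={digest}"

# Simulate a push-event delivery signed with a local test secret
body = json.dumps({"after": "abc123"}).encode()
signature = sign_payload("my-test-secret", body)
```

Send `body` with a `X-Hub-Signature-256: {signature}` header to the endpoint and the `hmac.compare_digest` check above should pass.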
### Gradio agent UI
For AI agents, a Gradio app lets you build an interactive UI that kicks off agent runs. The app uses `flyte.with_runcontext()` to run the agent task either locally or on a remote cluster, controlled by an environment variable.
```python
import os
import flyte
import flyte.app
from research_agent import agent
RUN_MODE = os.getenv("RUN_MODE", "remote")
serving_env = flyte.app.AppEnvironment(
name="research-agent-ui",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"gradio", "langchain-core", "langchain-openai", "langgraph",
),
secrets=flyte.Secret(key="OPENAI_API_KEY", as_env_var="OPENAI_API_KEY"),
port=7860,
)
def run_query(request: str):
    """Kick off the agent as a Flyte task."""
    result = flyte.with_runcontext(mode=RUN_MODE).run(agent, request=request)
    result.wait()
    return result.outputs()[0]

def create_demo():
    """Build the Gradio UI around run_query."""
    import gradio as gr

    return gr.Interface(fn=run_query, inputs="text", outputs="text", title="Research agent")

@serving_env.server
def app_server():
    create_demo().launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
create_demo().launch()
```
The `RUN_MODE` variable gives you a smooth development progression:
1. **Fully local**: `RUN_MODE=local python agent_app.py`. Everything runs in your local Python environment, great for rapid iteration.
2. **Local app, remote task**: `python agent_app.py`. The UI runs locally but the agent executes on the cluster with full compute resources.
3. **Full remote**: `flyte deploy agent_app.py serving_env`. Both the UI and agent run on the cluster.
## Call app from app
Apps can call other apps by making HTTP requests. This is useful for:
- Microservice architectures
- Proxy/gateway patterns
- A/B testing setups
- Service composition
### Example: App calling another app
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# "httpx",
# ]
# ///
"""Example of one app calling another app."""
import httpx
from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi", "uvicorn", "httpx"
)
# {{docs-fragment backend-app}}
app1 = FastAPI(
title="App 1",
description="A FastAPI app that runs some computations",
)
env1 = FastAPIAppEnvironment(
name="app1-is-called-by-app2",
app=app1,
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment backend-app}}
# {{docs-fragment frontend-app}}
app2 = FastAPI(
title="App 2",
description="A FastAPI app that proxies requests to another FastAPI app",
)
env2 = FastAPIAppEnvironment(
name="app2-calls-app1",
app=app2,
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
depends_on=[env1], # Depends on backend-api
)
# {{/docs-fragment frontend-app}}
# {{docs-fragment backend-endpoint}}
@app1.get("/greeting/{name}")
async def greeting(name: str) -> str:
return f"Hello, {name}!"
# {{/docs-fragment backend-endpoint}}
# {{docs-fragment frontend-endpoints}}
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
return env1.endpoint # Access the backend endpoint
@app2.get("/greeting/{name}")
async def greeting_proxy(name: str):
"""Proxy that calls the backend app."""
async with httpx.AsyncClient() as client:
response = await client.get(f"{env1.endpoint}/greeting/{name}")
response.raise_for_status()
return response.json()
# {{/docs-fragment frontend-endpoints}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
deployments = flyte.deploy(env2)
print(f"Deployed FastAPI app: {deployments[0].env_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/app_calling_app.py*
Key points:
- Use `depends_on=[env1]` to ensure dependencies are deployed first
- Access the app endpoint via `env1.endpoint`
- Use HTTP clients (like `httpx`) to make requests between apps
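Inter-app calls can fail transiently, for example while a dependency is redeploying. A small retry-with-backoff helper (illustrative, not part of the SDK) that you could wrap around the `httpx` call; the `sleep` parameter is injectable so the sketch is testable:

```python
import time

def retry(fn, attempts: int = 3, base_delay: float = 0.5, sleep=time.sleep):
    """Call fn(), retrying with exponential backoff on any exception."""
    for attempt in range(attempts):
        try:
            return fn()
        except Exception:
            if attempt == attempts - 1:
                raise  # out of attempts; propagate the last error
            sleep(base_delay * (2 ** attempt))

# Example: a call that fails twice, then succeeds
calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("backend not ready")
    return "ok"

result = retry(flaky, attempts=3, sleep=lambda _: None)
```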
### Using AppEndpoint parameter
You can pass app endpoints as parameters for more flexibility:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# "httpx",
# ]
# ///
"""Example of one app calling another app."""
import os
import httpx
from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi", "uvicorn", "httpx"
)
# {{docs-fragment backend-app}}
app1 = FastAPI(
title="App 1",
description="A FastAPI app that runs some computations",
)
env1 = FastAPIAppEnvironment(
name="app1-is-called-by-app2",
app=app1,
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
@app1.get("/greeting/{name}")
async def greeting(name: str) -> str:
return f"Hello, {name}!"
# {{/docs-fragment backend-app}}
# {{docs-fragment using-app-endpoint}}
app2 = FastAPI(
title="App 2",
description="A FastAPI app that proxies requests to another FastAPI app",
)
env2 = FastAPIAppEnvironment(
name="app2-calls-app1",
app=app2,
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
depends_on=[env1], # Depends on backend-api
parameters=[
flyte.app.Parameter(
name="app1_endpoint",
value=flyte.app.AppEndpoint(app_name="app1-is-called-by-app2"),
env_var="APP1_ENDPOINT",
),
],
)
@app2.get("/greeting/{name}")
async def greeting_proxy(name: str):
app1_endpoint = os.getenv("APP1_ENDPOINT")
async with httpx.AsyncClient() as client:
response = await client.get(f"{app1_endpoint}/greeting/{name}")
response.raise_for_status()
return response.json()
# {{/docs-fragment using-app-endpoint}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
deployments = flyte.deploy(env2)
print(f"Deployed FastAPI app: {deployments[0].env_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/app_calling_app_endpoint.py*
## WebSocket-based patterns
WebSockets enable bidirectional, real-time communication between clients and servers. Flyte apps can serve WebSocket endpoints for real-time applications like chat, live updates, or streaming data.
### Example: Basic WebSocket app
Here's a simple FastAPI app with WebSocket support:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# "websockets",
# ]
# ///
"""A FastAPI app with WebSocket support."""
import pathlib
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import asyncio
import json
from datetime import UTC, datetime
import flyte
from flyte.app.extras import FastAPIAppEnvironment
app = FastAPI(
title="Flyte WebSocket Demo",
description="A FastAPI app with WebSocket support",
version="1.0.0",
)
# {{docs-fragment connection-manager}}
class ConnectionManager:
"""Manages WebSocket connections."""
def __init__(self):
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket):
"""Accept and register a new WebSocket connection."""
await websocket.accept()
self.active_connections.append(websocket)
print(f"Client connected. Total: {len(self.active_connections)}")
def disconnect(self, websocket: WebSocket):
"""Remove a WebSocket connection."""
self.active_connections.remove(websocket)
print(f"Client disconnected. Total: {len(self.active_connections)}")
async def send_personal_message(self, message: str, websocket: WebSocket):
"""Send a message to a specific WebSocket connection."""
await websocket.send_text(message)
async def broadcast(self, message: str):
"""Broadcast a message to all active connections."""
for connection in self.active_connections:
try:
await connection.send_text(message)
except Exception as e:
print(f"Error broadcasting: {e}")
manager = ConnectionManager()
# {{/docs-fragment connection-manager}}
# {{docs-fragment websocket-endpoint}}
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for real-time communication."""
await manager.connect(websocket)
try:
# Send welcome message
await manager.send_personal_message(
json.dumps({
"type": "system",
"message": "Welcome! You are connected.",
"timestamp": datetime.now(UTC).isoformat(),
}),
websocket,
)
# Listen for messages
while True:
data = await websocket.receive_text()
# Echo back to sender
await manager.send_personal_message(
json.dumps({
"type": "echo",
"message": f"Echo: {data}",
"timestamp": datetime.now(UTC).isoformat(),
}),
websocket,
)
# Broadcast to all clients
await manager.broadcast(
json.dumps({
"type": "broadcast",
"message": f"Broadcast: {data}",
"timestamp": datetime.now(UTC).isoformat(),
"connections": len(manager.active_connections),
})
)
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(
json.dumps({
"type": "system",
"message": "A client disconnected",
"connections": len(manager.active_connections),
})
)
# {{/docs-fragment websocket-endpoint}}
# {{docs-fragment env}}
env = FastAPIAppEnvironment(
name="websocket-app",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
"websockets",
),
resources=flyte.Resources(cpu=1, memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed websocket app: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/basic_websocket.py*
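A small client sketch for exercising the `/ws` endpoint above, assuming the third-party `websockets` package and a placeholder endpoint URL. The connection code is kept inside the function so the message parsing stays importable without a server running:

```python
import asyncio
import json

def parse_message(raw: str) -> tuple[str, str]:
    """Unpack the {type, message, ...} JSON envelope the server sends."""
    payload = json.loads(raw)
    return payload["type"], payload["message"]

async def chat_once(url: str, text: str) -> tuple[str, str]:
    """Connect, send one message, and return the first echo received."""
    import websockets  # third-party; deferred import keeps the module importable

    async with websockets.connect(url) as ws:
        await ws.recv()  # skip the welcome message
        await ws.send(text)
        return parse_message(await ws.recv())

if __name__ == "__main__":
    # Replace with your deployed app's endpoint
    print(asyncio.run(chat_once("wss://your-websocket-app-url/ws", "hello")))
```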
### WebSocket patterns
**Echo server**
```python
@app.websocket("/echo")
async def echo(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Echo: {data}")
    except WebSocketDisconnect:
        pass
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*
**Broadcast server**
```python
# Connection manager shared by all broadcast clients
class ConnectionManager:
    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            try:
                await connection.send_text(message)
            except Exception:
                pass

manager = ConnectionManager()

@app.websocket("/broadcast")
async def broadcast(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(data)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*
**Real-time data streaming**
```python
@app.websocket("/stream")
async def stream_data(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            # Generate or fetch data
            data = {"timestamp": datetime.now(UTC).isoformat(), "value": random.random()}
            await websocket.send_json(data)
            await asyncio.sleep(1)  # Send an update every second
    except WebSocketDisconnect:
        pass
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*
**Chat application**
```python
class ChatRoom:
    def __init__(self, name: str):
        self.name = name
        self.connections: list[WebSocket] = []

    async def join(self, websocket: WebSocket):
        self.connections.append(websocket)

    async def leave(self, websocket: WebSocket):
        self.connections.remove(websocket)

    async def broadcast(self, message: str, sender: WebSocket):
        for connection in self.connections:
            if connection != sender:
                await connection.send_text(message)

rooms: dict[str, ChatRoom] = {}

@app.websocket("/chat/{room_name}")
async def chat(websocket: WebSocket, room_name: str):
    await websocket.accept()
    if room_name not in rooms:
        rooms[room_name] = ChatRoom(room_name)
    room = rooms[room_name]
    await room.join(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await room.broadcast(data, websocket)
    except WebSocketDisconnect:
        await room.leave(websocket)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*
"fastapi",
"uvicorn",
"websockets",
),
resources=flyte.Resources(cpu=1, memory="1Gi"),
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed WebSocket patterns app: {app_deployment[0].summary_repr()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*
### Using WebSockets with Flyte tasks
You can trigger Flyte tasks from WebSocket messages:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# "websockets",
# ]
# ///
"""A WebSocket app that triggers Flyte tasks and streams updates."""
import json
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize Flyte before accepting requests."""
await flyte.init_in_cluster.aio()
yield
app = FastAPI(
title="WebSocket Task Runner",
description="Triggers Flyte tasks via WebSocket and streams updates",
version="1.0.0",
lifespan=lifespan,
)
# {{docs-fragment task-runner-websocket}}
@app.websocket("/task-runner")
async def task_runner(websocket: WebSocket):
await websocket.accept()
try:
while True:
# Receive task request
message = await websocket.receive_text()
request = json.loads(message)
# Trigger Flyte task
task = remote.Task.get(
project=request["project"],
domain=request["domain"],
name=request["task"],
version=request["version"],
)
run = await flyte.run.aio(task, **request["inputs"])
# Send run info back
await websocket.send_json({
"run_id": run.id,
"url": run.url,
"status": "started",
})
# Optionally stream updates
async for update in run.stream():
await websocket.send_json({
"status": update.status,
"message": update.message,
})
except WebSocketDisconnect:
pass
# {{/docs-fragment task-runner-websocket}}
env = FastAPIAppEnvironment(
name="task-runner-websocket",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
"websockets",
),
resources=flyte.Resources(cpu=1, memory="1Gi"),
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed WebSocket task runner: {app_deployment[0].summary_repr()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/task_runner_websocket.py*
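Concretely, the `/task-runner` endpoint above expects each incoming text message to be a JSON object with the fields the handler reads: `project`, `domain`, `task`, `version`, and `inputs`. A sketch of such a message (the values are hypothetical):

```python
import json

# Hypothetical request message for the /task-runner endpoint above.
request = {
    "project": "my-project",
    "domain": "development",
    "task": "my_task",
    "version": "v1",
    "inputs": {"x": 42},
}
message = json.dumps(request)

# This is the string a client would send over the WebSocket;
# the handler parses it back with json.loads.
parsed = json.loads(message)
print(parsed["task"], parsed["inputs"])
```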
### WebSocket client example
Connect from Python:
```python
import asyncio
import websockets
import json
async def client():
uri = "ws://your-app-url/ws"
async with websockets.connect(uri) as websocket:
# Send message
await websocket.send("Hello, Server!")
# Receive message
response = await websocket.recv()
print(f"Received: {response}")
asyncio.run(client())
Adjust the URI to point at one of the deployed endpoints (for example, `/echo`); deployed apps are served over HTTPS, so use the `wss://` scheme.
```
## Browser-based apps
For browser-based apps (like Streamlit), users interact directly through the web interface. The app URL is accessible in a browser, and users work with the UI directly; no API calls from other services are needed.
To access a browser-based app:
1. Deploy the app
2. Navigate to the app URL in a browser
3. Interact with the UI directly
## Best practices
1. **Use `depends_on`**: Always specify dependencies to ensure proper deployment order.
2. **Handle errors**: Implement proper error handling for HTTP requests.
3. **Use async clients**: Use async HTTP clients (`httpx.AsyncClient`) in async contexts.
4. **Initialize Flyte**: For apps that call tasks, initialize Flyte during app startup (for example, in a FastAPI lifespan handler).
5. **Endpoint access**: Use `app_env.endpoint` or `AppEndpoint` parameter for accessing app URLs.
6. **Authentication**: Consider authentication when apps call each other (set `requires_auth=True` if needed).
7. **Webhook security**: Secure webhooks with auth, validation, and HTTPS.
8. **WebSocket robustness**: Implement connection management, heartbeats, and rate limiting.
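Best practice 8 mentions rate limiting; a common approach is a token bucket kept per connection. A minimal sketch (the `TokenBucket` class and its parameters are illustrative, not part of the Flyte API):

```python
import time

class TokenBucket:
    """Per-connection rate limiter: allow short bursts, then throttle."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate          # tokens replenished per second
        self.capacity = capacity  # maximum burst size
        self.tokens = float(capacity)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill proportionally to elapsed time, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

bucket = TokenBucket(rate=5.0, capacity=10)
allowed = [bucket.allow() for _ in range(12)]
print(allowed.count(True))  # the initial burst (up to capacity) is accepted; later calls are throttled
```

In a WebSocket handler, you would check `bucket.allow()` before processing each received message and drop or reject messages when it returns `False`.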
## Summary
| Pattern | Use Case | Implementation |
|---------|----------|----------------|
| Task → App | Batch processing using inference services | HTTP requests from task |
| App → Task | Webhooks, APIs triggering workflows | Flyte SDK in app |
| App → App | Microservices, proxies, agent routers, LLM routers | HTTP requests between apps |
| Browser → App | User-facing dashboards | Direct browser access |
Choose the pattern that best fits your architecture and requirements.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/secret-based-authentication ===
# Secret-based authentication
In this guide, we'll deploy a FastAPI app that uses API key authentication with Flyte secrets. This allows you to invoke the endpoint from the public internet securely without exposing API keys in your code.
## Create the secret
Before defining and deploying the app, you need to create the `API_KEY` secret in Flyte. This secret will store your API key securely.
Create the secret using the Flyte CLI:
```bash
flyte create secret API_KEY
```
For example:
```bash
flyte create secret API_KEY my-secret-api-key-12345
```
> [!NOTE]
> The secret name `API_KEY` must match the key specified in the `flyte.Secret()` call in your code. The secret will be available to your app as the environment variable specified in `as_env_var`.
## Define the FastAPI app
Here's a simple FastAPI app that uses `HTTPAuthorizationCredentials` to authenticate requests using a secret stored in Flyte:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""Basic FastAPI authentication using dependency injection."""
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# Get API key from environment variable (loaded from Flyte secret)
# The secret must be created using: flyte create secret API_KEY
API_KEY = os.getenv("API_KEY")
security = HTTPBearer()
async def verify_token(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
"""Verify the API key from the bearer token."""
if not API_KEY:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="API_KEY not configured",
)
if credentials.credentials != API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
return credentials
app = FastAPI(title="Authenticated API")
@app.get("/public")
async def public_endpoint():
"""Public endpoint that doesn't require authentication."""
return {"message": "This is public"}
@app.get("/protected")
async def protected_endpoint(
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
"""Protected endpoint that requires authentication."""
return {
"message": "This is protected",
"user": credentials.credentials,
}
env = FastAPIAppEnvironment(
name="authenticated-api",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False, # We handle auth in the app
secrets=flyte.Secret(key="API_KEY", as_env_var="API_KEY"),
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed: {app_deployment[0].summary_repr()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/basic_auth.py*
As you can see, we:
1. Define a `FastAPI` app
2. Create a `verify_token` function that verifies the API key from the Bearer token
3. Define endpoints that use the `verify_token` function to authenticate requests
4. Configure the `FastAPIAppEnvironment` with:
- `requires_auth=False` - This allows the endpoint to be reached without going through Flyte's authentication, since we're handling authentication ourselves using the `API_KEY` secret
- `secrets=flyte.Secret(key="API_KEY", as_env_var="API_KEY")` - This injects the secret value into the `API_KEY` environment variable at runtime
The key difference from using `env_vars` is that secrets are stored securely in Flyte's secret store and injected at runtime, rather than being passed as plain environment variables.
## Deploy the FastAPI app
Once the secret is created, you can deploy the FastAPI app. Make sure your `config.yaml` file is in the same directory as your script, then run:
```bash
python basic_auth.py
```
Or use the Flyte CLI:
```bash
flyte serve basic_auth.py
```
Deploying the application will stream the status to the console and display the app URL:
```
✨ Deploying Application: authenticated-api
🔗 Console URL: https:///console/projects/my-project/domains/development/apps/authenticated-api
[Status] Pending: App is pending deployment
[Status] Started: Service is ready
🚀 Deployed Endpoint: https://rough-meadow-97cf5.apps.
```
## Invoke the endpoint
Once deployed, you can invoke the authenticated endpoint using curl:
```bash
curl -X GET "https://rough-meadow-97cf5.apps./protected" \
  -H "Authorization: Bearer <your-api-key>"
```
Replace `<your-api-key>` with the actual API key value you used when creating the secret.
For example, if you created the secret with value `my-secret-api-key-12345`:
```bash
curl -X GET "https://rough-meadow-97cf5.apps./protected" \
-H "Authorization: Bearer my-secret-api-key-12345"
```
You should receive a response:
```json
{
"message": "This is protected",
"user": "my-secret-api-key-12345"
}
```
## Authentication for vLLM and SGLang apps
Both vLLM and SGLang apps support API key authentication through their native `--api-key` argument. This allows you to secure your LLM endpoints while keeping them accessible from the public internet.
### Create the authentication secret
Create a secret to store your API key:
```bash
flyte create secret AUTH_SECRET
```
For example:
```bash
flyte create secret AUTH_SECRET my-llm-api-key-12345
```
### Deploy vLLM app with authentication
Here's how to deploy a vLLM app with API key authentication:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-vllm>=2.0.0b45",
# ]
# ///
"""vLLM app with API key authentication."""
import pathlib
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# The secret must be created using: flyte create secret AUTH_SECRET
vllm_app = VLLMAppEnvironment(
name="vllm-app-with-auth",
model_hf_path="Qwen/Qwen3-0.6B", # HuggingFace model path
model_id="qwen3-0.6b", # Model ID exposed by vLLM
resources=flyte.Resources(
cpu="4",
memory="16Gi",
gpu="L40s:1", # GPU required for LLM serving
disk="10Gi",
),
scaling=flyte.app.Scaling(
replicas=(0, 1),
scaledown_after=300, # Scale down after 5 minutes of inactivity
),
# Disable Union's platform-level authentication so you can access the
# endpoint from the public internet
requires_auth=False,
# Inject the secret as an environment variable
secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
# Pass the API key to vLLM's --api-key argument
# The $AUTH_SECRET will be replaced with the actual secret value at runtime
extra_args=[
"--api-key", "$AUTH_SECRET",
],
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.serve(vllm_app)
print(f"Deployed vLLM app: {app.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/vllm_with_auth.py*
Key points:
1. **`requires_auth=False`** - Disables Union's platform-level authentication so the endpoint can be accessed from the public internet
2. **`secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET")`** - Injects the secret as an environment variable
3. **`extra_args=["--api-key", "$AUTH_SECRET"]`** - Passes the API key to vLLM's `--api-key` argument. The `$AUTH_SECRET` will be replaced with the actual secret value at runtime
Deploy the app:
```bash
python vllm_with_auth.py
```
Or use the Flyte CLI:
```bash
flyte serve vllm_with_auth.py
```
### Deploy SGLang app with authentication
Here's how to deploy an SGLang app with API key authentication:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-sglang>=2.0.0b45",
# ]
# ///
"""SGLang app with API key authentication."""
import pathlib
from flyteplugins.sglang import SGLangAppEnvironment
import flyte
# The secret must be created using: flyte create secret AUTH_SECRET
sglang_app = SGLangAppEnvironment(
name="sglang-with-auth",
model_hf_path="Qwen/Qwen3-0.6B", # HuggingFace model path
model_id="qwen3-0.6b", # Model ID exposed by SGLang
resources=flyte.Resources(
cpu="4",
memory="16Gi",
gpu="L40s:1", # GPU required for LLM serving
disk="10Gi",
),
scaling=flyte.app.Scaling(
replicas=(0, 1),
scaledown_after=300, # Scale down after 5 minutes of inactivity
),
# Disable Union's platform-level authentication so you can access the
# endpoint from the public internet
requires_auth=False,
# Inject the secret as an environment variable
secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
# Pass the API key to SGLang's --api-key argument
# The $AUTH_SECRET will be replaced with the actual secret value at runtime
extra_args=[
"--api-key", "$AUTH_SECRET",
],
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.serve(sglang_app)
print(f"Deployed SGLang app: {app.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/sglang_with_auth.py*
The configuration is similar to vLLM:
1. **`requires_auth=False`** - Disables Union's platform-level authentication
2. **`secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET")`** - Injects the secret as an environment variable
3. **`extra_args=["--api-key", "$AUTH_SECRET"]`** - Passes the API key to SGLang's `--api-key` argument
Deploy the app:
```bash
python sglang_with_auth.py
```
Or use the Flyte CLI:
```bash
flyte serve sglang_with_auth.py
```
### Invoke authenticated LLM endpoints
Once deployed, you can invoke the authenticated endpoints using the OpenAI-compatible API format. Both vLLM and SGLang expose OpenAI-compatible endpoints.
For example, to make a chat completion request:
```bash
curl -X POST "https://your-app-url/v1/chat/completions" \
-H "Content-Type: application/json" \
  -H "Authorization: Bearer <your-api-key>" \
-d '{
"model": "qwen3-0.6b",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
]
}'
```
Replace `<your-api-key>` with the actual API key value you used when creating the secret.
For example, if you created the secret with value `my-llm-api-key-12345`:
```bash
curl -X POST "https://your-app-url/v1/chat/completions" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer my-llm-api-key-12345" \
-d '{
"model": "qwen3-0.6b",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
]
}'
```
You should receive a response with the model's completion.
> [!NOTE]
> The `$AUTH_SECRET` syntax in `extra_args` is automatically replaced with the actual secret value at runtime. This ensures the API key is never exposed in your code or configuration files.
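The same chat-completion request can be assembled from Python with the standard library (the URL and key below are placeholders for your deployed values):

```python
import json
import urllib.request

# Hypothetical endpoint and key; substitute your deployed values.
payload = json.dumps({
    "model": "qwen3-0.6b",
    "messages": [{"role": "user", "content": "Hello, how are you?"}],
}).encode()
req = urllib.request.Request(
    "https://your-app-url/v1/chat/completions",
    data=payload,
    headers={
        "Content-Type": "application/json",
        "Authorization": "Bearer my-llm-api-key-12345",
    },
    method="POST",
)
print(req.get_method(), req.full_url)
# Uncomment to send the request:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp))
```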
## Accessing Swagger documentation
The app also includes a public health check endpoint and Swagger UI documentation:
- **Health check**: `https://your-app-url/health`
- **Swagger UI**: `https://your-app-url/docs`
- **ReDoc**: `https://your-app-url/redoc`
The Swagger UI will show an "Authorize" button where you can enter your Bearer token to test authenticated endpoints directly from the browser.
## Security best practices
1. **Use strong API keys**: Generate cryptographically secure random strings for your API keys
2. **Rotate keys regularly**: Periodically rotate your API keys for better security
3. **Scope secrets appropriately**: Use project/domain scoping when creating secrets if you want to limit access:
```bash
flyte create secret --project my-project --domain development API_KEY my-secret-value
```
4. **Never commit secrets**: Always use Flyte secrets for API keys, never hardcode them in your code
5. **Use HTTPS**: Always use HTTPS in production (Flyte apps are served over HTTPS by default)
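Practice 1 is easy with Python's standard `secrets` module:

```python
import secrets

# A 32-byte URL-safe token makes a strong, random API key.
api_key = secrets.token_urlsafe(32)
print(len(api_key))  # 43 characters of base64url
```

Pass the generated value to `flyte create secret` when creating your API key secret.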
## Troubleshooting
**Authentication failing:**
- Verify the secret exists: `flyte get secret API_KEY`
- Check that the secret key name matches exactly (case-sensitive)
- Ensure you're using the correct Bearer token value
- Verify the `as_env_var` parameter matches the environment variable name in your code
**Secret not found:**
- Make sure you've created the secret before deploying the app
- Check the secret scope (organization vs project/domain) matches your app's project/domain
- Verify the secret name matches exactly (should be `API_KEY`)
**App not starting:**
- Check container logs for errors
- Verify all dependencies are installed in the image
- Ensure the secret is accessible in the app's project/domain
**LLM app authentication not working:**
- Verify the secret exists: `flyte get secret AUTH_SECRET`
- Check that `$AUTH_SECRET` is correctly specified in `extra_args` (note the `$` prefix)
- Ensure the secret name matches exactly (case-sensitive) in both the `flyte.Secret()` call and `extra_args`
- For vLLM, verify the `--api-key` argument is correctly passed
- For SGLang, verify the `--api-key` argument is correctly passed
- Check that `requires_auth=False` is set to allow public access
## Next steps
- Learn more about [managing secrets](../task-configuration/secrets) in Flyte
- See [app usage patterns](./app-usage-patterns#call-task-from-app-webhooks--apis) for webhook examples and authentication patterns
- Learn about [vLLM apps](./vllm-app) and [SGLang apps](./sglang-app) for serving LLMs
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/streamlit-app ===
# Streamlit app
Streamlit is a popular framework for building interactive web applications and dashboards. Flyte makes it easy to deploy Streamlit apps as long-running services.
## Basic Streamlit app
The simplest way to deploy a Streamlit app is to use the built-in Streamlit "hello" demo:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# ///
"""A basic Streamlit app using the built-in hello demo."""
# {{docs-fragment app-definition}}
import flyte
import flyte.app
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1")
app_env = flyte.app.AppEnvironment(
name="streamlit-hello",
image=image,
args="streamlit hello --server.port 8080",
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.deploy(app_env)
print(f"Deployed app: {app[0].summary_repr()}")
# {{/docs-fragment app-definition}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/basic_streamlit.py*
This just serves the built-in Streamlit "hello" demo.
## Single-file Streamlit app
For a single-file Streamlit app, you can wrap the app code in a function and use the `args` parameter to specify the command that runs the app.
Note that the command runs the file itself with `streamlit run`, and passes a custom `--server` flag (after `--`) that the script uses to decide whether to run the app code or to deploy the environment.
This is useful when you have a relatively small and simple app that you want to deploy as a single file.
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "streamlit",
# ]
# ///
"""A single-script Streamlit app example."""
import sys
from pathlib import Path
import streamlit as st
import flyte
import flyte.app
# {{docs-fragment streamlit-app}}
def main():
    st.set_page_config(page_title="Simple Streamlit App", page_icon="🎈")
st.title("Hello from Streamlit!")
st.write("This is a simple single-script Streamlit app.")
name = st.text_input("What's your name?", "World")
st.write(f"Hello, {name}!")
if st.button("Click me!"):
st.balloons()
st.success("Button clicked!")
# {{/docs-fragment streamlit-app}}
file_name = Path(__file__).name
# {{docs-fragment app-env}}
app_env = flyte.app.AppEnvironment(
name="streamlit-single-script",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1"),
args=[
"streamlit",
"run",
file_name,
"--server.port",
"8080",
"--",
"--server",
],
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
import logging
import sys
if "--server" in sys.argv:
main()
else:
flyte.init_from_config(
root_dir=Path(__file__).parent,
log_level=logging.DEBUG,
)
app = flyte.serve(app_env)
print(f"App URL: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/single_file_streamlit.py*
Note that the `if __name__ == "__main__"` block is used to both serve the `AppEnvironment` *and* run the app code via
the `streamlit run` command using the `--server` flag.
## Multi-file Streamlit app
When your Streamlit application grows more complex, you may want to split it into multiple files.
For a multi-file Streamlit app, use the `include` parameter to bundle your app files:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# ///
"""A custom Streamlit app with multiple files."""
import pathlib
import flyte
import flyte.app
# {{docs-fragment app-env}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1",
"pandas==2.2.3",
"numpy==2.2.3",
)
app_env = flyte.app.AppEnvironment(
name="streamlit-multi-file-app",
image=image,
args="streamlit run main.py --server.port 8080",
port=8080,
include=["main.py", "utils.py"], # Include your app files
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.deploy(app_env)
print(f"Deployed app: {app[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/multi_file_streamlit.py*
Where your project structure looks like this:
```
project/
├── main.py          # Main Streamlit app
├── utils.py         # Utility functions
└── components.py    # Reusable components
```
Your `main.py` file would contain your Streamlit app code:
```
import os
import streamlit as st
from utils import generate_data
# {{docs-fragment streamlit-app}}
all_columns = ["Apples", "Orange", "Pineapple"]
with st.container(border=True):
columns = st.multiselect("Columns", all_columns, default=all_columns)
all_data = st.cache_data(generate_data)(columns=all_columns, seed=101)
data = all_data[columns]
tab1, tab2 = st.tabs(["Chart", "Dataframe"])
tab1.line_chart(data, height=250)
tab2.dataframe(data, height=250, use_container_width=True)
st.write(f"Environment: {os.environ}")
# {{/docs-fragment streamlit-app}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/main.py*
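`main.py` imports `generate_data` from `utils.py`, which is not shown on this page. A minimal sketch of what such a helper could look like (the real implementation lives in the example repo; this version is hypothetical):

```python
import numpy as np
import pandas as pd

def generate_data(columns: list[str], seed: int = 0) -> pd.DataFrame:
    # Hypothetical helper: one random-walk series per column name.
    rng = np.random.default_rng(seed)
    values = rng.standard_normal((50, len(columns))).cumsum(axis=0)
    return pd.DataFrame(values, columns=columns)

df = generate_data(["Apples", "Orange"], seed=101)
print(df.shape)
```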
## Example: Data visualization dashboard
Here's a complete example of a Streamlit dashboard, all in a single file.
Define the streamlit app in the `main` function:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "streamlit",
#     "pandas",
#     "numpy",
# ]
# ///
"""A data visualization dashboard example using Streamlit."""
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st

import flyte
import flyte.app

# {{docs-fragment streamlit-app}}
def main():
    st.set_page_config(page_title="Sales Dashboard", page_icon="📊")
    st.title("Sales Dashboard")

    # Load data
    @st.cache_data
    def load_data():
        return pd.DataFrame({
            "date": pd.date_range("2024-01-01", periods=100, freq="D"),
            "sales": np.random.randint(1000, 5000, 100),
        })

    data = load_data()

    # Sidebar filters
    st.sidebar.header("Filters")
    start_date = st.sidebar.date_input("Start date", value=data["date"].min())
    end_date = st.sidebar.date_input("End date", value=data["date"].max())

    # Filter data
    filtered_data = data[
        (data["date"] >= pd.Timestamp(start_date))
        & (data["date"] <= pd.Timestamp(end_date))
    ]

    # Display metrics
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Sales", f"${filtered_data['sales'].sum():,.0f}")
    with col2:
        st.metric("Average Sales", f"${filtered_data['sales'].mean():,.0f}")
    with col3:
        st.metric("Days", len(filtered_data))

    # Chart
    st.line_chart(filtered_data.set_index("date")["sales"])
# {{/docs-fragment streamlit-app}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/data_visualization_dashboard.py*
Define the `AppEnvironment` to serve the app:
```
# {{docs-fragment app-env}}
file_name = Path(__file__).name

app_env = flyte.app.AppEnvironment(
    name="sales-dashboard",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "streamlit==1.41.1",
        "pandas==2.2.3",
        "numpy==2.2.3",
    ),
    args=["streamlit", "run", file_name, "--server.port", "8080", "--", "--server"],
    port=8080,
    resources=flyte.Resources(cpu="2", memory="2Gi"),
    requires_auth=False,
)
# {{/docs-fragment app-env}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/data_visualization_dashboard.py*
And finally the app serving logic:
```
# {{docs-fragment serve}}
if __name__ == "__main__":
    import logging
    import sys

    if "--server" in sys.argv:
        main()
    else:
        flyte.init_from_config(
            root_dir=Path(__file__).parent,
            log_level=logging.DEBUG,
        )
        app = flyte.serve(app_env)
        print(f"Dashboard URL: {app.url}")
# {{/docs-fragment serve}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/data_visualization_dashboard.py*
## Best practices
1. **Use `include` for custom apps**: Always include your app files when deploying custom Streamlit code
2. **Set the port correctly**: Ensure your Streamlit app uses `--server.port 8080` (or match your `port` setting)
3. **Cache data**: Use `@st.cache_data` for expensive computations to improve performance
4. **Resource sizing**: Adjust resources based on your app's needs (data size, computations)
5. **Public vs private**: Set `requires_auth=False` for public dashboards, `True` for internal tools
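For the first point, here is a sketch of what explicitly including app files looks like. This assumes `AppEnvironment` accepts an `include` parameter (as referenced in the troubleshooting notes below); the file names are placeholders for your own modules.

```python
import flyte
import flyte.app

# Sketch: explicitly include the app's source files so they are
# bundled into the container at deploy time (file names are placeholders).
app_env = flyte.app.AppEnvironment(
    name="sales-dashboard",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "streamlit==1.41.1",
    ),
    args=["streamlit", "run", "dashboard.py", "--server.port", "8080"],
    port=8080,
    include=["dashboard.py", "helpers.py"],
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)
```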
## Troubleshooting
**App not loading:**
- Verify the port matches (use `--server.port 8080`)
- Check that all required files are included
- Review container logs for errors
**Missing dependencies:**
- Ensure all required packages are in your image's pip packages
- Check that file paths in `include` are correct
**Performance issues:**
- Increase CPU/memory resources
- Use Streamlit's caching features (`@st.cache_data`, `@st.cache_resource`)
- Optimize data processing
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/fastapi-app ===
# FastAPI app
FastAPI is a modern, fast web framework for building APIs. Flyte provides `FastAPIAppEnvironment`, which makes it easy to deploy FastAPI applications.
## Basic FastAPI app
Here's a simple FastAPI app:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""A basic FastAPI app example."""
from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment fastapi-app}}
app = FastAPI(
    title="My API",
    description="A simple FastAPI application",
    version="1.0.0",
)
# {{/docs-fragment fastapi-app}}
# {{docs-fragment fastapi-env}}
env = FastAPIAppEnvironment(
    name="my-fastapi-app",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)
# {{/docs-fragment fastapi-env}}
# {{docs-fragment endpoints}}
@app.get("/")
async def root():
    return {"message": "Hello, World!"}


@app.get("/health")
async def health_check():
    return {"status": "healthy"}
# {{/docs-fragment endpoints}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/basic_fastapi.py*
Once deployed, you can:
- Access the API at the generated URL
- View interactive API docs at `/docs` (Swagger UI)
- View alternative docs at `/redoc`
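The documentation endpoints are fixed paths off the app's base URL. The URL below is a hypothetical placeholder; use the one printed for your own deployment.

```python
# Hypothetical base URL; substitute the URL printed by flyte.deploy.
base_url = "https://my-fastapi-app.example.com"

# FastAPI serves its docs at fixed paths relative to the base URL
endpoints = {
    "swagger": f"{base_url}/docs",
    "redoc": f"{base_url}/redoc",
    "openapi": f"{base_url}/openapi.json",
}
print(endpoints["swagger"])  # https://my-fastapi-app.example.com/docs
```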
## Serving a machine learning model
Here's an example of serving a scikit-learn model:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# "scikit-learn",
# "joblib",
# ]
# ///
"""Example of serving a machine learning model with FastAPI."""
import os
from contextlib import asynccontextmanager
from pathlib import Path
import joblib
import flyte
from fastapi import FastAPI
from flyte.app.extras import FastAPIAppEnvironment
from pydantic import BaseModel
# {{docs-fragment ml-model}}
# Define request/response models
class PredictionRequest(BaseModel):
    feature1: float
    feature2: float
    feature3: float


class PredictionResponse(BaseModel):
    prediction: float
    probability: float


# Load model (you would typically load this from storage)
model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    model_path = os.getenv("MODEL_PATH", "/app/models/model.joblib")
    # In production, load from your storage
    if os.path.exists(model_path):
        with open(model_path, "rb") as f:
            model = joblib.load(f)
    yield


# Register the lifespan handler so the model loads on startup
app = FastAPI(title="ML Model API", lifespan=lifespan)
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    # Make prediction
    # prediction = model.predict([[request.feature1, request.feature2, request.feature3]])
    # Dummy prediction for demo
    prediction = 0.85
    probability = 0.92
    return PredictionResponse(
        prediction=prediction,
        probability=probability,
    )
env = FastAPIAppEnvironment(
    name="ml-model-api",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
        "scikit-learn",
        "pydantic",
        "joblib",
    ),
    parameters=[
        flyte.app.Parameter(
            name="model_file",
            value=flyte.io.File("s3://bucket/models/model.joblib"),
            mount="/app/models",
            env_var="MODEL_PATH",
        ),
    ],
    resources=flyte.Resources(cpu=2, memory="2Gi"),
    requires_auth=False,
)
# {{/docs-fragment ml-model}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"API URL: {app_deployment[0].url}")
    print(f"Swagger docs: {app_deployment[0].url}/docs")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/ml_model_serving.py*
## Accessing Swagger documentation
FastAPI automatically generates interactive API documentation. Once deployed:
- **Swagger UI**: Access at `{app_url}/docs`
- **ReDoc**: Access at `{app_url}/redoc`
- **OpenAPI JSON**: Access at `{app_url}/openapi.json`
The Swagger UI provides an interactive interface where you can:
- See all available endpoints
- Test API calls directly from the browser
- View request/response schemas
- See example payloads
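The OpenAPI JSON can also be inspected programmatically. The snippet below lists an API's endpoints from an OpenAPI document; the inline dict is a minimal stand-in for what a live `GET {app_url}/openapi.json` would return.

```python
import json

# Minimal stand-in for the JSON served at {app_url}/openapi.json
openapi_doc = json.loads("""
{
  "openapi": "3.1.0",
  "paths": {
    "/": {"get": {"summary": "Root"}},
    "/health": {"get": {"summary": "Health Check"}}
  }
}
""")

# List every method + path pair the API exposes
routes = [
    f"{method.upper()} {path}"
    for path, ops in openapi_doc["paths"].items()
    for method in ops
]
print(routes)  # ['GET /', 'GET /health']
```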
## Example: REST API with multiple endpoints
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""Example REST API with multiple endpoints."""
from pathlib import Path
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment rest-api}}
app = FastAPI(title="Product API")
# Data models
class Product(BaseModel):
    id: int
    name: str
    price: float


class ProductCreate(BaseModel):
    name: str
    price: float


# In-memory database (use real database in production)
products_db = []


@app.get("/products", response_model=List[Product])
async def get_products():
    return products_db


@app.get("/products/{product_id}", response_model=Product)
async def get_product(product_id: int):
    product = next((p for p in products_db if p["id"] == product_id), None)
    if not product:
        raise HTTPException(status_code=404, detail="Product not found")
    return product


@app.post("/products", response_model=Product)
async def create_product(product: ProductCreate):
    new_product = {
        "id": len(products_db) + 1,
        "name": product.name,
        "price": product.price,
    }
    products_db.append(new_product)
    return new_product
env = FastAPIAppEnvironment(
    name="product-api",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)
# {{/docs-fragment rest-api}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"API URL: {app_deployment[0].url}")
    print(f"Swagger docs: {app_deployment[0].url}/docs")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/rest_api.py*
## Multi-file FastAPI app
Here's an example of a multi-file FastAPI app:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "fastapi",
# ]
# ///
"""Multi-file FastAPI app example."""
from fastapi import FastAPI
from module import function # Import from another file
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment app-definition}}
app = FastAPI(title="Multi-file FastAPI Demo")
app_env = FastAPIAppEnvironment(
    name="fastapi-multi-file",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    # FastAPIAppEnvironment automatically includes necessary files
    # But you can also specify explicitly:
    # include=["app.py", "module.py"],
)
# {{/docs-fragment app-definition}}
# {{docs-fragment endpoint}}
@app.get("/")
async def root():
    return function()  # Uses function from module.py
# {{/docs-fragment endpoint}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(app_env)
    print(f"Deployed: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/app.py*
The helper module:
```
# {{docs-fragment helper-function}}
def function():
    """Helper function used by the FastAPI app."""
    return {"message": "Hello from module.py!"}
# {{/docs-fragment helper-function}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/module.py*
See [Multi-script apps](./multi-script-apps) for more details on building FastAPI apps with multiple files.
## Local-to-remote model serving
A common ML pattern: train a model with a Flyte pipeline, then serve predictions from it. During local development, the app loads the model from a local file (e.g. `model.pt` saved by your training pipeline). When deployed remotely, Flyte's `Parameter` system automatically resolves the model from the latest training run output.
```python
from contextlib import asynccontextmanager
from pathlib import Path
import os
from fastapi import FastAPI
import flyte
from flyte.app import Parameter, RunOutput
from flyte.app.extras import FastAPIAppEnvironment
MODEL_PATH_ENV = "MODEL_PATH"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup, from either a local file or a remote run output."""
    model_path = Path(os.environ.get(MODEL_PATH_ENV, "model.pt"))
    model = load_model(model_path)  # load_model: your own model-loading helper (not shown)
    app.state.model = model
    yield


app = FastAPI(title="MNIST Predictor", lifespan=lifespan)

serving_env = FastAPIAppEnvironment(
    name="mnist-predictor",
    app=app,
    parameters=[
        # Remote: resolves model from the latest train run and sets MODEL_PATH
        Parameter(
            name="model",
            value=RunOutput(task_name="ml_pipeline.pipeline", type="file", getter=(1,)),
            download=True,
            env_var=MODEL_PATH_ENV,
        ),
    ],
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi", "uvicorn", "torch", "torchvision",
    ),
    resources=flyte.Resources(cpu=1, memory="4Gi"),
)


@app.get("/predict")
async def predict(index: int = 0) -> dict:
    return {"prediction": app.state.model(index)}


if __name__ == "__main__":
    # Local: skip RunOutput resolution, lifespan falls back to local model.pt
    serving_env.parameters = []
    local_app = flyte.with_servecontext(mode="local").serve(serving_env)
    local_app.activate(wait=True)
```
Locally, the app loads `model.pt` from disk:
```bash
python serve_model.py
```
Remotely, Flyte resolves the model from the latest training run:
```bash
flyte deploy serve_model.py serving_env
```
The key idea: `Parameter` with `RunOutput` bridges the gap between local and remote. Locally, the app falls back to a local file. Remotely, Flyte resolves the model artifact from the latest pipeline run automatically.
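The fallback itself is plain environment-variable logic. Here is a minimal, self-contained sketch of how the lifespan resolves the model path (`MODEL_PATH` is the env var name used in the example above; the remote path below is illustrative):

```python
import os
from pathlib import Path

MODEL_PATH_ENV = "MODEL_PATH"

def resolve_model_path(default: str = "model.pt") -> Path:
    """Return the remote-mounted model path if Flyte set MODEL_PATH, else the local default."""
    return Path(os.environ.get(MODEL_PATH_ENV, default))

# Local run: MODEL_PATH is unset, so we fall back to the local file
os.environ.pop(MODEL_PATH_ENV, None)
print(resolve_model_path())  # model.pt

# Remote run: Flyte's Parameter sets MODEL_PATH to the downloaded artifact
os.environ[MODEL_PATH_ENV] = "/tmp/flyte/model.pt"
print(resolve_model_path())  # /tmp/flyte/model.pt
```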
## Best practices
1. **Use Pydantic models**: Define request/response models for type safety and automatic validation
2. **Handle errors**: Use HTTPException for proper error responses
3. **Async operations**: Use async/await for I/O operations
4. **Environment variables**: Use environment variables for configuration
5. **Logging**: Add proper logging for debugging and monitoring
6. **Health checks**: Always include a `/health` endpoint
7. **API documentation**: FastAPI auto-generates docs, but add descriptions to your endpoints
## Advanced features
FastAPI supports many features that work with Flyte:
- **Dependencies**: Use FastAPI's dependency injection system
- **Background tasks**: Run background tasks with BackgroundTasks
- **WebSockets**: See [WebSocket-based patterns](./app-usage-patterns#websocket-based-patterns) for details
- **Authentication**: Add authentication middleware (see [secret-based authentication](./secret-based-authentication))
- **CORS**: Configure CORS for cross-origin requests
- **Rate limiting**: Add rate limiting middleware
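For instance, CORS is configured with FastAPI's standard `CORSMiddleware`; a minimal sketch (the allowed origin is a placeholder for your own frontend's domain):

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Allow a specific frontend origin to call this API from the browser
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://my-frontend.example.com"],  # placeholder origin
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
```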
## Troubleshooting
**App not starting:**
- Check that uvicorn can find your app module
- Verify all dependencies are installed in the image
- Check container logs for startup errors
**Import errors:**
- Ensure all imported modules are available
- Use `include` parameter if you have custom modules
- Check that file paths are correct
**API not accessible:**
- Verify `requires_auth` setting
- Check that the app is listening on the correct port (8080)
- Review network/firewall settings
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/vllm-app ===
# vLLM app
vLLM is a high-performance library for serving large language models (LLMs). Flyte provides `VLLMAppEnvironment` for deploying vLLM model servers.
## Installation
First, install the vLLM plugin:
```bash
pip install flyteplugins-vllm
```
## Basic vLLM app
Here's a simple example serving a HuggingFace model:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-vllm>=2.0.0b45",
# ]
# ///
"""A simple vLLM app example."""
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# {{docs-fragment basic-vllm-app}}
vllm_app = VLLMAppEnvironment(
    name="my-llm-app",
    model_hf_path="Qwen/Qwen3-0.6B",  # HuggingFace model path
    model_id="qwen3-0.6b",  # Model ID exposed by vLLM
    resources=flyte.Resources(
        cpu="4",
        memory="16Gi",
        gpu="L40s:1",  # GPU required for LLM serving
        disk="10Gi",
    ),
    scaling=flyte.app.Scaling(
        replicas=(0, 1),
        scaledown_after=300,  # Scale down after 5 minutes of inactivity
    ),
    requires_auth=False,
)
# {{/docs-fragment basic-vllm-app}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    app = flyte.serve(vllm_app)
    print(f"Deployed vLLM app: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/basic_vllm.py*
## Using prefetched models
You can use models prefetched with `flyte.prefetch`:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-vllm>=2.0.0b45",
# ]
# override-dependencies = [
# "cel-python; sys_platform == 'never'",
# ]
# ///
"""vLLM app using prefetched models."""
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# {{docs-fragment prefetch}}
vllm_app = VLLMAppEnvironment(
    name="my-llm-app",
    model_hf_path="Qwen/Qwen3-0.6B",  # this is a placeholder
    model_id="qwen3-0.6b",
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1", disk="10Gi"),
    stream_model=True,  # Stream model directly from blob store to GPU
    requires_auth=False,
)

if __name__ == "__main__":
    flyte.init_from_config()

    # Prefetch the model first
    run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")
    run.wait()

    # Use the prefetched model
    app = flyte.serve(
        vllm_app.clone_with(
            vllm_app.name,
            model_hf_path=None,
            model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
        )
    )
    print(f"Deployed vLLM app: {app.url}")
# {{/docs-fragment prefetch}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/vllm_with_prefetch.py*
## Model streaming
`VLLMAppEnvironment` supports streaming models directly from blob storage to GPU memory, reducing startup time.
When `stream_model=True` and the `model_path` argument is provided with either a `flyte.io.Dir` or a `RunOutput` pointing to a path in the object store:
- Model weights stream directly from storage to GPU
- Faster startup time (no full download required)
- Lower disk space requirements
> [!NOTE]
> The contents of the model directory must be compatible with the vLLM-supported formats, e.g. the HuggingFace model
> serialization format.
## Custom vLLM arguments
Use `extra_args` to pass additional arguments to vLLM:
```python
vllm_app = VLLMAppEnvironment(
    name="custom-vllm-app",
    model_hf_path="Qwen/Qwen3-0.6B",
    model_id="qwen3-0.6b",
    extra_args=[
        "--max-model-len", "8192",  # Maximum context length
        "--gpu-memory-utilization", "0.8",  # GPU memory utilization
        "--trust-remote-code",  # Trust remote code in models
    ],
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
    # ...
)
```
See the [vLLM documentation](https://docs.vllm.ai/en/stable/configuration/engine_args.html) for all available arguments.
## Using the OpenAI-compatible API
Once deployed, your vLLM app exposes an OpenAI-compatible API:
```python
from openai import OpenAI
client = OpenAI(
    base_url="https://your-app-url/v1",  # vLLM endpoint
    api_key="your-api-key",  # If you passed an --api-key argument
)

response = client.chat.completions.create(
    model="qwen3-0.6b",  # Your model_id
    messages=[
        {"role": "user", "content": "Hello, how are you?"}
    ],
)
print(response.choices[0].message.content)
```
> [!TIP]
> If you passed an `--api-key` argument, you can use the `api_key` parameter to authenticate your requests.
> See [here](./secret-based-authentication#deploy-vllm-app-with-authentication) for more details on how to pass auth secrets to your app.
## Multi-GPU inference (Tensor Parallelism)
For larger models, use multiple GPUs with tensor parallelism:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-vllm>=2.0.0b45",
# ]
# ///
"""vLLM app with multi-GPU tensor parallelism."""
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# {{docs-fragment multi-gpu}}
vllm_app = VLLMAppEnvironment(
    name="multi-gpu-llm-app",
    model_hf_path="meta-llama/Llama-2-70b-hf",
    model_id="llama-2-70b",
    resources=flyte.Resources(
        cpu="8",
        memory="32Gi",
        gpu="L40s:4",  # 4 GPUs for tensor parallelism
        disk="100Gi",
    ),
    extra_args=[
        "--tensor-parallel-size", "4",  # Use 4 GPUs
        "--max-model-len", "4096",
        "--gpu-memory-utilization", "0.9",
    ],
    requires_auth=False,
)
# {{/docs-fragment multi-gpu}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    app = flyte.serve(vllm_app)
    print(f"Deployed vLLM app: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/vllm_multi_gpu.py*
The `--tensor-parallel-size` value should match the number of GPUs specified in `resources`.
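Since a mismatch here is an easy mistake, a small (hypothetical) helper can derive the count from the `gpu` spec string and keep the two in sync:

```python
def gpu_count(gpu_spec: str) -> int:
    """Parse the device count out of a Flyte GPU spec like 'L40s:4'."""
    _, _, count = gpu_spec.partition(":")
    return int(count) if count else 1

# Derive tensor parallelism from the same string used in Resources(gpu=...)
tp_size = gpu_count("L40s:4")
extra_args = ["--tensor-parallel-size", str(tp_size), "--max-model-len", "4096"]
print(tp_size)  # 4
```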
## Model sharding with prefetch
You can prefetch and shard models for multi-GPU inference:
```python
# Prefetch with sharding configuration
run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-70b-hf",
    accelerator="L40s:4",
    shard_config=flyte.prefetch.ShardConfig(
        engine="vllm",
        args=flyte.prefetch.VLLMShardArgs(
            tensor_parallel_size=4,
            dtype="auto",
            trust_remote_code=True,
        ),
    ),
)
run.wait()

# Use the sharded model
vllm_app = VLLMAppEnvironment(
    name="sharded-llm-app",
    model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
    model_id="llama-2-70b",
    resources=flyte.Resources(cpu="8", memory="32Gi", gpu="L40s:4", disk="100Gi"),
    extra_args=["--tensor-parallel-size", "4"],
    stream_model=True,
)
```
See [Prefetching models](../serve-and-deploy-apps/prefetching-models) for more details on sharding.
## Autoscaling
vLLM apps work well with autoscaling:
```python
vllm_app = VLLMAppEnvironment(
    name="autoscaling-llm-app",
    model_hf_path="Qwen/Qwen3-0.6B",
    model_id="qwen3-0.6b",
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
    scaling=flyte.app.Scaling(
        replicas=(0, 1),  # Scale to zero when idle
        scaledown_after=600,  # 10 minutes idle before scaling down
    ),
    # ...
)
```
## Best practices
1. **Use prefetching**: Prefetch models for faster deployment and better reproducibility
2. **Enable streaming**: Use `stream_model=True` to reduce startup time and disk usage
3. **Right-size GPUs**: Match GPU memory to model size
4. **Configure memory utilization**: Use `--gpu-memory-utilization` to control memory usage
5. **Use tensor parallelism**: For large models, use multiple GPUs with `tensor-parallel-size`
6. **Set autoscaling**: Use appropriate idle TTL to balance cost and performance
7. **Limit context length**: Use `--max-model-len` for smaller models to reduce memory usage
## Troubleshooting
**Model loading fails:**
- Verify GPU memory is sufficient for the model
- Check that the model path or HuggingFace path is correct
- Review container logs for detailed error messages
**Out of memory errors:**
- Reduce `--max-model-len`
- Lower `--gpu-memory-utilization`
- Use a smaller model or more GPUs
**Slow startup:**
- Enable `stream_model=True` for faster loading
- Prefetch models before deployment
- Use faster storage backends
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/sglang-app ===
# SGLang app
SGLang is a fast structured generation library for large language models (LLMs). Flyte provides `SGLangAppEnvironment` for deploying SGLang model servers.
## Installation
First, install the SGLang plugin:
```bash
pip install flyteplugins-sglang
```
## Basic SGLang app
Here's a simple example serving a HuggingFace model:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-sglang>=2.0.0b45",
# ]
# ///
"""A simple SGLang app example."""
from flyteplugins.sglang import SGLangAppEnvironment
import flyte
# {{docs-fragment basic-sglang-app}}
sglang_app = SGLangAppEnvironment(
    name="my-sglang-app",
    model_hf_path="Qwen/Qwen3-0.6B",  # HuggingFace model path
    model_id="qwen3-0.6b",  # Model ID exposed by SGLang
    resources=flyte.Resources(
        cpu="4",
        memory="16Gi",
        gpu="L40s:1",  # GPU required for LLM serving
        disk="10Gi",
    ),
    scaling=flyte.app.Scaling(
        replicas=(0, 1),
        scaledown_after=300,  # Scale down after 5 minutes of inactivity
    ),
    requires_auth=False,
)
# {{/docs-fragment basic-sglang-app}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    app = flyte.serve(sglang_app)
    print(f"Deployed SGLang app: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/basic_sglang.py*
## Using prefetched models
You can use models prefetched with `flyte.prefetch`:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-sglang>=2.0.0b45",
# ]
# ///
"""SGLang app using prefetched models."""
from flyteplugins.sglang import SGLangAppEnvironment
import flyte
# {{docs-fragment prefetch}}
sglang_app = SGLangAppEnvironment(
    name="my-sglang-app",
    model_hf_path="Qwen/Qwen3-0.6B",  # this is a placeholder
    model_id="qwen3-0.6b",
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1", disk="10Gi"),
    stream_model=True,  # Stream model directly from blob store to GPU
    requires_auth=False,
)

if __name__ == "__main__":
    flyte.init_from_config()

    # Prefetch the model first
    run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")
    run.wait()

    # Use the prefetched model
    app = flyte.serve(
        sglang_app.clone_with(
            sglang_app.name,
            model_hf_path=None,
            model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
        )
    )
    print(f"Deployed SGLang app: {app.url}")
# {{/docs-fragment prefetch}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/sglang_with_prefetch.py*
## Model streaming
`SGLangAppEnvironment` supports streaming models directly from blob storage to GPU memory, reducing startup time.
When `stream_model=True` and the `model_path` argument is provided with either a `flyte.io.Dir` or a `RunOutput` pointing to a path in the object store:
- Model weights stream directly from storage to GPU
- Faster startup time (no full download required)
- Lower disk space requirements
> [!NOTE]
> The contents of the model directory must be compatible with the SGLang-supported formats, e.g. the HuggingFace model
> serialization format.
## Custom SGLang arguments
Use `extra_args` to pass additional arguments to SGLang:
```python
sglang_app = SGLangAppEnvironment(
    name="custom-sglang-app",
    model_hf_path="Qwen/Qwen3-0.6B",
    model_id="qwen3-0.6b",
    extra_args=[
        "--max-model-len", "8192",  # Maximum context length
        "--mem-fraction-static", "0.8",  # Memory fraction for static allocation
        "--trust-remote-code",  # Trust remote code in models
    ],
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
    # ...
)
```
See the [SGLang server arguments documentation](https://docs.sglang.io/advanced_features/server_arguments.html) for all available options.
## Using the OpenAI-compatible API
Once deployed, your SGLang app exposes an OpenAI-compatible API:
```python
from openai import OpenAI
client = OpenAI(
    base_url="https://your-app-url/v1",  # SGLang endpoint
    api_key="your-api-key",  # If you passed an --api-key argument
)

response = client.chat.completions.create(
    model="qwen3-0.6b",  # Your model_id
    messages=[
        {"role": "user", "content": "Hello, how are you?"}
    ],
)
print(response.choices[0].message.content)
```
> [!TIP]
> If you passed an `--api-key` argument, you can use the `api_key` parameter to authenticate your requests.
> See [here](./secret-based-authentication#deploy-sglang-app-with-authentication) for more details on how to pass auth secrets to your app.
## Multi-GPU inference (Tensor Parallelism)
For larger models, use multiple GPUs with tensor parallelism:
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-sglang>=2.0.0b45",
# ]
# ///
"""SGLang app with multi-GPU tensor parallelism."""
from flyteplugins.sglang import SGLangAppEnvironment
import flyte
# {{docs-fragment multi-gpu}}
sglang_app = SGLangAppEnvironment(
    name="multi-gpu-sglang-app",
    model_hf_path="meta-llama/Llama-2-70b-hf",
    model_id="llama-2-70b",
    resources=flyte.Resources(
        cpu="8",
        memory="32Gi",
        gpu="L40s:4",  # 4 GPUs for tensor parallelism
        disk="100Gi",
    ),
    extra_args=[
        "--tp", "4",  # Tensor parallelism size (4 GPUs)
        "--max-model-len", "4096",
        "--mem-fraction-static", "0.9",
    ],
    requires_auth=False,
)
# {{/docs-fragment multi-gpu}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    app = flyte.serve(sglang_app)
    print(f"Deployed SGLang app: {app.url}")
# {{/docs-fragment deploy}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/sglang_multi_gpu.py*
The tensor parallelism size (`--tp`) should match the number of GPUs specified in resources.
## Model sharding with prefetch
You can prefetch and shard models for multi-GPU inference using SGLang's sharding:
```python
# Prefetch with sharding configuration
run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-70b-hf",
    accelerator="L40s:4",
    shard_config=flyte.prefetch.ShardConfig(
        engine="vllm",
        args=flyte.prefetch.VLLMShardArgs(
            tensor_parallel_size=4,
            dtype="auto",
            trust_remote_code=True,
        ),
    ),
)
run.wait()

# Use the sharded model
sglang_app = SGLangAppEnvironment(
    name="sharded-sglang-app",
    model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
    model_id="llama-2-70b",
    resources=flyte.Resources(cpu="8", memory="32Gi", gpu="L40s:4", disk="100Gi"),
    extra_args=["--tp", "4"],
    stream_model=True,
)
```
See [Prefetching models](../serve-and-deploy-apps/prefetching-models) for more details on sharding.
## Autoscaling
SGLang apps work well with autoscaling:
```python
sglang_app = SGLangAppEnvironment(
    name="autoscaling-sglang-app",
    model_hf_path="Qwen/Qwen3-0.6B",
    model_id="qwen3-0.6b",
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
    scaling=flyte.app.Scaling(
        replicas=(0, 1),  # Scale to zero when idle
        scaledown_after=600,  # 10 minutes idle before scaling down
    ),
    # ...
)
```
## Structured generation
SGLang is particularly well-suited for structured generation tasks. The deployed app supports standard OpenAI API calls, and you can use SGLang's advanced features through the API.
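For example, JSON-constrained output can be requested through the OpenAI-compatible API's `response_format` field. Note this is an assumption about your SGLang release (check its server documentation for the exact options supported); the sample completion string below is illustrative.

```python
import json

# Request kwargs you would pass to client.chat.completions.create(...)
request_kwargs = {
    "model": "qwen3-0.6b",
    "messages": [{"role": "user", "content": "Give me a city and its country as JSON."}],
    # Assumption: your SGLang version supports OpenAI-style JSON mode
    "response_format": {"type": "json_object"},
}

# A sample completion string, as the constrained model would return it
sample_output = '{"city": "Paris", "country": "France"}'
parsed = json.loads(sample_output)
print(parsed["city"])  # Paris
```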
## Best practices
1. **Use prefetching**: Prefetch models for faster deployment and better reproducibility
2. **Enable streaming**: Use `stream_model=True` to reduce startup time and disk usage
3. **Right-size GPUs**: Match GPU memory to model size
4. **Use tensor parallelism**: For large models, use multiple GPUs with `--tp`
5. **Set autoscaling**: Use appropriate idle TTL to balance cost and performance
6. **Configure memory**: Use `--mem-fraction-static` to control memory allocation
7. **Limit context length**: Use `--max-model-len` for smaller models to reduce memory usage
## Troubleshooting
**Model loading fails:**
- Verify GPU memory is sufficient for the model
- Check that the model path or HuggingFace path is correct
- Review container logs for detailed error messages
**Out of memory errors:**
- Reduce `--max-model-len`
- Lower `--mem-fraction-static`
- Use a smaller model or more GPUs
**Slow startup:**
- Enable `stream_model=True` for faster loading
- Prefetch models before deployment
- Use faster storage backends
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps ===
# Serve and deploy apps
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
Flyte provides two main ways to deploy apps: **serve** (for development) and **deploy** (for production). This section covers both methods and their differences.
## Serve vs Deploy
### `flyte serve`
Serving is designed for development and iteration:
- **Dynamic parameter modification**: You can override app parameters when serving
- **Quick iteration**: Faster feedback loop for development
- **Interactive**: Better suited for testing and experimentation
### `flyte deploy`
Deployment is designed for production use:
- **Immutable**: Apps are deployed with fixed configurations
- **Production-ready**: Optimized for stability and reproducibility
## Using Python SDK
### Serve
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# ///
"""Serve and deploy examples for the _index.md documentation."""
import flyte
import flyte.app
# {{docs-fragment serve-example}}
app_env = flyte.app.AppEnvironment(
    name="my-app",
    image=flyte.app.Image.from_debian_base().with_pip_packages("streamlit==1.41.1"),
    args=["streamlit", "hello", "--server.port", "8080"],
    port=8080,
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)

if __name__ == "__main__":
    flyte.init_from_config()
    app = flyte.serve(app_env)
    print(f"Served at: {app.url}")
# {{/docs-fragment serve-example}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/serve_and_deploy_examples.py*
### Deploy
```
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# ]
# ///
"""Serve and deploy examples for the _index.md documentation."""
import flyte
import flyte.app
# {{docs-fragment deploy-example}}
app_env = flyte.app.AppEnvironment(
    name="my-app",
    image=flyte.app.Image.from_debian_base().with_pip_packages("streamlit==1.41.1"),
    args=["streamlit", "hello", "--server.port", "8080"],
    port=8080,
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)

if __name__ == "__main__":
    flyte.init_from_config()
    deployments = flyte.deploy(app_env)
    # Access deployed app URL from the deployment
    for deployed_env in deployments[0].envs.values():
        print(f"Deployed: {deployed_env.deployed_app.url}")
# {{/docs-fragment deploy-example}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/serve_and_deploy_examples.py*
## Using the CLI
### Serve
```bash
flyte serve path/to/app.py app_env
```
### Deploy
```bash
flyte deploy path/to/app.py app_env
```
## Next steps
- **Serve and deploy apps > How app serving works**: Understanding the serve process and configuration options
- **Serve and deploy apps > How app deployment works**: Understanding the deploy process and configuration options
- **Serve and deploy apps > Activating and deactivating apps**: Managing app lifecycle
- **Basic project: RAG**: Train a model with tasks and serve it via FastAPI
- **Serve and deploy apps > Prefetching models**: Download and shard HuggingFace models for vLLM and SGLang
## Subpages
- **Serve and deploy apps > How app serving works**
- **Serve and deploy apps > How app deployment works**
- **Serve and deploy apps > Activating and deactivating apps**
- **Serve and deploy apps > Prefetching models**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/how-app-serving-works ===
# How app serving works
Serving is the recommended way to run apps during development. It provides a fast feedback loop and lets you modify parameters dynamically.
## Overview
When you serve an app, the following happens:
1. **Code bundling**: Your app code is bundled and prepared
2. **Image building**: Container images are built (if needed)
3. **Deployment**: The app is deployed to your Flyte cluster
4. **Activation**: The app is automatically activated and ready to use
5. **URL generation**: A URL is generated for accessing the app
## Using the Python SDK
The simplest way to serve an app:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Serve examples for the how-app-serving-works.md documentation."""
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="my-dev-app",
    parameters=[flyte.app.Parameter(name="model_path", value="s3://bucket/models/model.pkl")],
    # ...
)

if __name__ == "__main__":
    flyte.init_from_config()
    app = flyte.serve(app_env)
    print(f"App served at: {app.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/serve_examples.py*
## Overriding parameters
One key advantage of serving is the ability to override parameters dynamically:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Serve examples for the how-app-serving-works.md documentation."""
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="my-dev-app",
    parameters=[flyte.app.Parameter(name="model_path", value="s3://bucket/models/model.pkl")],
    # ...
)

app = flyte.with_servecontext(
    input_values={
        "my-dev-app": {
            "model_path": "s3://bucket/models/test-model.pkl",
        }
    }
).serve(app_env)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/serve_examples.py*
This is useful for:
- Testing different configurations
- Using different models or data sources
- A/B testing during development
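The shape of the `input_values` mapping can be pictured with plain dictionaries: outer keys select an app environment by name, and inner keys replace matching parameter defaults. The sketch below uses a hypothetical `resolve_parameters` helper (not part of the Flyte SDK) to illustrate those override semantics:

```python
# Hypothetical helper (not a Flyte API) showing how input_values
# overrides are resolved against declared parameter defaults.
def resolve_parameters(app_name, defaults, input_values):
    # Look up overrides for this app environment by name.
    overrides = input_values.get(app_name, {})
    # Overrides win; parameters without an override keep their default.
    return {**defaults, **overrides}

defaults = {"model_path": "s3://bucket/models/model.pkl"}
input_values = {
    "my-dev-app": {"model_path": "s3://bucket/models/test-model.pkl"},
}

resolved = resolve_parameters("my-dev-app", defaults, input_values)
print(resolved["model_path"])  # prints the test-model override
```

An app environment whose name does not appear in `input_values` keeps all of its declared defaults.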
## Advanced serving options
Use `with_servecontext()` for more control over the serving process:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Serve examples for the how-app-serving-works.md documentation."""
import logging

import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="my-dev-app",
    parameters=[flyte.app.Parameter(name="model_path", value="s3://bucket/models/model.pkl")],
    # ...
)

app = flyte.with_servecontext(
    version="v1.0.0",
    project="my-project",
    domain="development",
    env_vars={"LOG_LEVEL": "DEBUG"},
    input_values={"app-name": {"input": "value"}},
    cluster_pool="dev-pool",
    log_level=logging.INFO,
    log_format="json",
    dry_run=False,
).serve(app_env)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/serve_examples.py*
## Using the CLI
You can also serve apps from the command line:
```bash
flyte serve path/to/app.py app
```
Where `app` is the variable name of the `AppEnvironment` object.
## Return value
`flyte.serve()` returns an `App` object with:
- `url`: The app's URL
- `endpoint`: The app's endpoint URL
- `deployment_status`: Current status of the app
- `name`: App name
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Serve examples for the how-app-serving-works.md documentation."""
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="my-dev-app",
    parameters=[flyte.app.Parameter(name="model_path", value="s3://bucket/models/model.pkl")],
    # ...
)

app = flyte.serve(app_env)
print(f"URL: {app.url}")
print(f"Endpoint: {app.endpoint}")
print(f"Status: {app.deployment_status}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/serve_examples.py*
## Best practices
1. **Use for development**: App serving is ideal for development and testing.
2. **Override parameters**: Take advantage of parameter overrides for testing different configurations.
3. **Quick iteration**: Use `serve` for rapid development cycles.
4. **Switch to deploy**: Use [deploy](./how-app-deployment-works) for production deployments.
## Troubleshooting
**App not activating:**
- Check cluster connectivity
- Verify app configuration is correct
- Review container logs for errors
**Parameter overrides not working:**
- Verify parameter names match exactly
- Check that parameters are defined in the app environment
- Ensure you're using the `input_values` parameter correctly
**Slow serving:**
- Images may need to be built (first time is slower).
- Large code bundles can slow down deployment.
- Check network connectivity to the cluster.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/how-app-deployment-works ===
# How app deployment works
Deployment is the recommended way to release apps to production. It creates versioned, immutable app deployments.
## Overview
When you deploy an app, the following happens:
1. **Code bundling**: Your app code is bundled and prepared
2. **Image building**: Container images are built (if needed)
3. **Deployment**: The app is deployed to your Flyte cluster
4. **Activation**: The app is automatically activated and ready to use
## Using the Python SDK
Deploy an app:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Deploy examples for the how-app-deployment-works.md documentation."""
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="my-prod-app",
    # ...
)

if __name__ == "__main__":
    flyte.init_from_config()
    deployments = flyte.deploy(app_env)
    # Access deployed apps from deployments
    for deployment in deployments:
        for deployed_env in deployment.envs.values():
            print(f"Deployed: {deployed_env.env.name}")
            print(f"URL: {deployed_env.deployed_app.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/deploy_examples.py*
`flyte.deploy()` returns a list of `Deployment` objects. Each `Deployment` contains a dictionary of `DeployedEnvironment` objects (one for each environment deployed, including environment dependencies). For apps, the `DeployedEnvironment` is a `DeployedAppEnvironment` which has a `deployed_app` property of type `App`.
## Deployment plan
Flyte automatically creates a deployment plan that includes:
- The app you're deploying
- All [app environment dependencies](../configure-apps/apps-depending-on-environments) (via `depends_on`)
- Proper deployment order
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Deploy examples for the how-app-deployment-works.md documentation."""
import flyte
import flyte.app

app1_env = flyte.app.AppEnvironment(name="backend", ...)
app2_env = flyte.app.AppEnvironment(name="frontend", depends_on=[app1_env], ...)

# Deploying app2_env will also deploy app1_env
deployments = flyte.deploy(app2_env)

# deployments contains both app1_env and app2_env
assert len(deployments) == 2
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/deploy_examples.py*
## Overriding app configuration at deployment time
If you need to override the app configuration at deployment time, use the `clone_with` method to create a new app environment with the desired overrides.
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Deploy examples for the how-app-deployment-works.md documentation."""
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(name="my-app", ...)

if __name__ == "__main__":
    flyte.init_from_config()
    deployments = flyte.deploy(
        app_env.clone_with(app_env.name, resources=flyte.Resources(cpu="2", memory="2Gi"))
    )
    for deployment in deployments:
        for deployed_env in deployment.envs.values():
            print(f"Deployed: {deployed_env.env.name}")
            print(f"URL: {deployed_env.deployed_app.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/deploy_examples.py*
## Activation/deactivation
Unlike serving, deployment does not automatically activate apps. You need to activate them explicitly:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Deploy examples for the how-app-deployment-works.md documentation."""
import flyte
import flyte.app
from flyte.remote import App

app_env = flyte.app.AppEnvironment(
    name="my-prod-app",
    # ...
)

if __name__ == "__main__":
    flyte.init_from_config()
    deployments = flyte.deploy(app_env)
    app = App.get(name=app_env.name)
    # Deactivate the app
    app.deactivate()
    # Activate the app
    app.activate()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/deploy_examples.py*
See [Activating and deactivating apps](./activating-and-deactivating-apps) for more details.
## Using the CLI
Deploy from the command line:
```bash
flyte deploy path/to/app.py app
```
Where `app` is the variable name of the `AppEnvironment` object.
You can also specify the following options:
```bash
flyte deploy path/to/app.py app \
--version v1.0.0 \
--project my-project \
--domain production \
--dry-run
```
## Example: Full deployment configuration
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Deploy examples for the how-app-deployment-works.md documentation."""
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="my-prod-app",
    # ...
)

if __name__ == "__main__":
    flyte.init_from_config()
    deployments = flyte.deploy(
        app_env,
        dryrun=False,
        version="v1.0.0",
        interactive_mode=False,
        copy_style="loaded_modules",
    )
    # Access deployed apps from deployments
    for deployment in deployments:
        for deployed_env in deployment.envs.values():
            app = deployed_env.deployed_app
            print(f"Deployed: {deployed_env.env.name}")
            print(f"URL: {app.url}")
            # Activate the app
            app.activate()
            print(f"Activated: {app.name}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/deploy_examples.py*
## Best practices
1. **Use for production**: Deploy is designed for production use.
2. **Version everything**: Always specify versions for reproducibility.
3. **Test first**: Test with serve before deploying to production.
4. **Manage dependencies**: Use `depends_on` to manage app dependencies.
5. **Activation strategy**: Have a strategy for activating/deactivating apps.
6. **Use dry-run**: Test deployments with `dry_run=True` first.
7. **Separate environments**: Use different projects/domains for different environments.
8. **Parameter management**: Consider using environment-specific parameter values.
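Environment-specific parameter values can be managed with an ordinary lookup keyed by domain. The sketch below is illustrative only: the domain names, paths, and the `parameters_for` helper are hypothetical, not part of the Flyte SDK.

```python
# Illustrative only: pick parameter values per domain before
# constructing the AppEnvironment for that domain.
PARAMETERS_BY_DOMAIN = {
    "development": {"model_path": "s3://bucket/models/test-model.pkl"},
    "production": {"model_path": "s3://bucket/models/model.pkl"},
}

def parameters_for(domain):
    # Fall back to development values for unknown domains.
    return PARAMETERS_BY_DOMAIN.get(domain, PARAMETERS_BY_DOMAIN["development"])

print(parameters_for("production")["model_path"])
```

The resulting dictionary can then be turned into `flyte.app.Parameter` objects when building the app environment for each domain.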
## Deployment status and return value
`flyte.deploy()` returns a list of `Deployment` objects. Each `Deployment` contains a dictionary of `DeployedEnvironment` objects:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Deploy examples for the how-app-deployment-works.md documentation."""
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="my-prod-app",
    # ...
)

deployments = flyte.deploy(app_env)
for deployment in deployments:
    for deployed_env in deployment.envs.values():
        if hasattr(deployed_env, "deployed_app"):
            # Access the deployed environment and app
            env = deployed_env.env
            app = deployed_env.deployed_app
            # Access deployment info
            print(f"Name: {env.name}")
            print(f"URL: {app.url}")
            print(f"Status: {app.deployment_status}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/deploy_examples.py*
For apps, each `DeployedAppEnvironment` includes:
- `env`: The `AppEnvironment` that was deployed
- `deployed_app`: The `App` object with properties like `url`, `endpoint`, `name`, and `deployment_status`
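The nesting of the return value can be pictured with plain dataclasses. These are hypothetical stand-ins, not the real SDK types (which carry more fields), but they mirror the shape that the loops in this page traverse:

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for the SDK types, used only to illustrate
# the nesting: list of Deployment -> envs dict -> deployed_app.
@dataclass
class MockApp:
    name: str
    url: str
    deployment_status: str

@dataclass
class MockDeployedAppEnvironment:
    env_name: str           # simplified; the SDK exposes the full AppEnvironment as `env`
    deployed_app: MockApp

@dataclass
class MockDeployment:
    envs: dict = field(default_factory=dict)

deployment = MockDeployment(envs={
    "my-prod-app": MockDeployedAppEnvironment(
        env_name="my-prod-app",
        deployed_app=MockApp("my-prod-app", "https://example.test/app", "ACTIVE"),
    )
})

# Walking the structure mirrors the loops used with flyte.deploy():
for deployed_env in deployment.envs.values():
    print(deployed_env.env_name, deployed_env.deployed_app.url)
```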
## Troubleshooting
**Deployment fails:**
- Check that all dependencies are available
- Verify image builds succeed
- Review deployment logs
**App not accessible:**
- Ensure the app is activated
- Check cluster connectivity
- Verify app configuration
**Version conflicts:**
- Use unique versions for each deployment
- Check existing app versions
- Clean up old versions if needed
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/activating-and-deactivating-apps ===
# Activating and deactivating apps
Apps deployed with `flyte.deploy()` need to be explicitly activated before they can serve traffic. Apps served with `flyte.serve()` are automatically activated.
## Activation
### Activate after deployment
After deploying an app, activate it:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
# ]
# ///
"""Activation examples for the activating-and-deactivating-apps.md documentation."""
import flyte
import flyte.app
from flyte.remote import App

app_env = flyte.app.AppEnvironment(
    name="my-app",
    # ...
)

# Deploy the app
deployments = flyte.deploy(app_env)

# Activate the app
app = App.get(name=app_env.name)
app.activate()
print(f"Activated app: {app.name}")
print(f"URL: {app.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
### Activate an app
Retrieving an app by name returns the current app instance, which you can then activate:
```python
app = App.get(name="my-app")
app.activate()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
### Check activation status
Check if an app is active:
```python
app = App.get(name="my-app")
print(f"Active: {app.is_active()}")
print(f"Revision: {app.revision}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
## Deactivation
Deactivate an app when you no longer need it:
```python
app = App.get(name="my-app")
app.deactivate()
print(f"Deactivated app: {app.name}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
## Lifecycle management
### Typical deployment workflow
```python
# 1. Deploy the new version
deployments = flyte.deploy(
    app_env,
    version="v2.0.0",
)

# 2. Get the deployed app
new_app = App.get(name="my-app")

# Test endpoints, etc.

# 3. Activate the new version
new_app.activate()
print(f"Deployed and activated version {new_app.revision}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
### Blue-green deployment
For zero-downtime deployments:
```python
# Deploy the new version without deactivating the old one
new_deployments = flyte.deploy(
    app_env,
    version="v2.0.0",
)
new_app = App.get(name="my-app")

# Test the new version
# ... testing ...

# Switch traffic to the new version
new_app.activate()
print(f"Activated revision {new_app.revision}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
## Using the CLI
### Activate
```bash
flyte update app --activate my-app
```
### Deactivate
```bash
flyte update app --deactivate my-app
```
### Check status
```bash
flyte get app my-app
```
Use `--project` and `--domain` to target a specific [project-domain pair](../projects-and-domains).
For all available options, see the [CLI reference](../../api-reference/flyte-cli).
## Best practices
1. **Activate after testing**: Test a deployed app before activating it
2. **Version management**: Keep track of which version is currently active
3. **Blue-green deployments**: Use blue-green deployments for zero-downtime releases
4. **Monitor**: Monitor apps after activation
5. **Cleanup**: Deactivate and remove old versions periodically
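As a sketch of the first practice, activation can be gated on a health check of the newly deployed revision before traffic is switched over. The helper below and its health-report shape are hypothetical (adapt them to whatever your app's health endpoint returns); only the `App.get`/`activate` calls referenced in the comment come from the API shown in this page.

```python
# Hypothetical helper: decide whether a deployed revision is healthy enough
# to activate. The report shape (a status string plus an error count) is an
# assumption for illustration, not part of the Flyte API.
def ready_to_activate(health: dict) -> bool:
    """Return True when the deployed revision looks safe to receive traffic."""
    return health.get("status") == "ok" and health.get("errors", 0) == 0

# Usage against a real deployment (Flyte calls as shown in this page;
# fetch_health is yours to implement against your app's health endpoint):
#
#   from flyte.remote import App
#   app = App.get(name="my-app")
#   if ready_to_activate(fetch_health(app.url)):
#       app.activate()
#   else:
#       print("Health check failed; leaving the previous revision active")
```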
## Automatic activation with serve
Apps served with `flyte.serve()` are automatically activated:
```python
# Automatically activated
app = flyte.serve(app_env)
print(f"Active: {app.is_active()}")  # True
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
This is convenient for development but less suitable for production where you want explicit control over activation.
## Example: Complete deployment and activation
```python
import flyte
import flyte.app
from flyte.remote import App

app_env = flyte.app.AppEnvironment(
    name="my-prod-app",
    # ... configuration ...
)

if __name__ == "__main__":
    flyte.init_from_config()

    # Deploy
    deployments = flyte.deploy(
        app_env,
        version="v1.0.0",
        project="my-project",
        domain="production",
    )

    # Get the deployed app
    app = App.get(name="my-prod-app")

    # Activate
    app.activate()

    print(f"Deployed and activated: {app.name}")
    print(f"Revision: {app.revision}")
    print(f"URL: {app.url}")
    print(f"Active: {app.is_active()}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/activation_examples.py*
## Troubleshooting
**App not accessible after activation:**
- Verify activation succeeded
- Check app logs for startup errors
- Verify cluster connectivity
- Check that the app is listening on the correct port
**Activation fails:**
- Check that the app was deployed successfully
- Verify app configuration is correct
- Check cluster resources
- Review deployment logs
**Cannot deactivate:**
- Ensure you have proper permissions
- Check if there are dependencies preventing deactivation
- Verify the app name and version
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/prefetching-models ===
# Prefetching models
Prefetching allows you to download and prepare HuggingFace models (including sharding for multi-GPU inference) before
deploying [vLLM](../build-apps/vllm-app) or [SGLang](../build-apps/sglang-app) apps. This speeds up deployment and ensures models are ready when your app starts.
## Why prefetch?
Prefetching models provides several benefits:
- **Faster deployment**: Models are pre-downloaded, so apps start faster
- **Reproducibility**: Models are versioned and stored in Flyte's object store
- **Sharding support**: Pre-shard models for multi-GPU tensor parallelism
- **Cost efficiency**: Download once, use many times
- **Offline support**: Models are cached in your storage backend
## Basic prefetch
### Using the Python SDK
```python
import flyte

# Prefetch a HuggingFace model
run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")

# Wait for the prefetch to complete
run.wait()

# Get the model path
model_path = run.outputs()[0].path
print(f"Model prefetched to: {model_path}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
### Using the CLI
```bash
flyte prefetch hf-model Qwen/Qwen3-0.6B
```
Wait for completion:
```bash
flyte prefetch hf-model Qwen/Qwen3-0.6B --wait
```
## Using prefetched models
Use the prefetched model in your vLLM or SGLang app:
```python
import flyte
from flyteplugins.vllm import VLLMAppEnvironment

# Prefetch the model
run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")
run.wait()

# Use the prefetched model
vllm_app = VLLMAppEnvironment(
    name="my-llm-app",
    model_path=flyte.app.RunOutput(
        type="directory",
        run_name=run.name,
    ),
    model_id="qwen3-0.6b",
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
    stream_model=True,
)

app = flyte.serve(vllm_app)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
> [!TIP]
> You can also use prefetched models as parameters to your generic `AppEnvironment`s or `FastAPIAppEnvironment`s.
## Prefetch options
### Custom artifact name
```python
import flyte

run = flyte.prefetch.hf_model(
    repo="Qwen/Qwen3-0.6B",
    artifact_name="qwen-0.6b-model",  # Custom name for the stored model
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
### With HuggingFace token
If the model requires authentication:
```python
import flyte

run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-7b-hf",
    hf_token_key="HF_TOKEN",  # Name of Flyte secret containing HF token
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
The default value of `hf_token_key` is `HF_TOKEN`, the name of the Flyte secret containing your HuggingFace token.
If this secret doesn't exist yet, create it with the [`flyte create secret` CLI](../task-configuration/secrets).
### With resources
By default, the prefetch task uses minimal resources (2 CPUs, 8GB of memory, 50Gi of disk storage) and streams
the model weights from HuggingFace directly to your storage backend.
Some HuggingFace models don't support file streaming; in that case the prefetch task falls back to downloading
the model weights to the task pod's disk first and then uploading them to your storage backend. If so, you can
specify custom resources for the prefetch task to override the defaults:
```python
import flyte

run = flyte.prefetch.hf_model(
    repo="Qwen/Qwen3-0.6B",
    cpu="4",
    mem="16Gi",
    ephemeral_storage="100Gi",
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
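Since streaming needs no local copy but the fallback path does, it helps to budget the pod's disk against the size of the model weights. The helper below is a minimal sketch of that arithmetic; the 1.5× headroom factor is an illustrative assumption, not a Flyte default, and the ~130 GiB figure is the rough fp16 weight size of a 70B-parameter model.

```python
def ephemeral_storage_needed(model_size_gib: float, safety_factor: float = 1.5) -> str:
    """Estimate ephemeral storage for a fallback download: the full weights land
    on the task pod's disk before upload, so budget the model size plus headroom."""
    needed = model_size_gib * safety_factor
    # round up to a whole Gi value, matching the string form Flyte resources use
    return f"{int(needed) + 1}Gi" if needed > int(needed) else f"{int(needed)}Gi"

# A 70B model in fp16 is roughly 130 GiB of weights, far beyond the 50Gi default
print(ephemeral_storage_needed(130))  # "195Gi"
```

If the estimate exceeds the default 50Gi, pass the result as `ephemeral_storage` when calling `flyte.prefetch.hf_model`.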
## Sharding models for multi-GPU
### vLLM sharding
Shard a model for tensor parallelism:
```python
import flyte
from flyte.prefetch import ShardConfig, VLLMShardArgs

run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-70b-hf",
    resources=flyte.Resources(cpu="8", memory="32Gi", gpu="L40s:4"),
    shard_config=ShardConfig(
        engine="vllm",
        args=VLLMShardArgs(
            tensor_parallel_size=4,
            dtype="auto",
            trust_remote_code=True,
        ),
    ),
    hf_token_key="HF_TOKEN",
)
run.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
Currently, the `flyte.prefetch.hf_model` function only supports sharding models
using the `vllm` engine. Once sharded, these models can be loaded with other
frameworks such as `transformers`, `torch`, or `sglang`.
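When picking `tensor_parallel_size`, vLLM requires that the model's attention-head count divide evenly across shards, and the shard count should match the GPU count the serving app will have. A quick local sanity check (the helper is illustrative, not part of the Flyte SDK; 64 is the attention-head count of Llama-2-70b):

```python
def valid_tensor_parallel(num_attention_heads: int, tp_size: int, gpu_count: int) -> bool:
    """Check sharding preconditions: heads split evenly across shards,
    and the shard count matches the GPUs available at serving time."""
    return num_attention_heads % tp_size == 0 and tp_size == gpu_count

print(valid_tensor_parallel(64, 4, 4))  # True: 64 heads split cleanly over 4 GPUs
print(valid_tensor_parallel(64, 3, 3))  # False: 64 heads don't divide across 3 shards
```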
### Using shard config via CLI
You can also supply the sharding configuration as a YAML file to the
`flyte prefetch hf-model` CLI command:
```yaml
# shard_config.yaml
engine: vllm
args:
tensor_parallel_size: 8
dtype: auto
trust_remote_code: true
```
Then run the CLI command:
```bash
flyte prefetch hf-model meta-llama/Llama-2-70b-hf \
--shard-config shard_config.yaml \
--accelerator L40s:8 \
--hf-token-key HF_TOKEN
```
## Using prefetched sharded models
After prefetching and sharding, serve the model in your app:
```python
import flyte
from flyte.prefetch import ShardConfig, VLLMShardArgs
from flyteplugins.vllm import VLLMAppEnvironment

# Use in vLLM app
vllm_app = VLLMAppEnvironment(
    name="multi-gpu-llm-app",
    # this will download the model from HuggingFace into the app container's filesystem
    model_hf_path="meta-llama/Llama-2-70b-hf",
    model_id="llama-2-70b",
    resources=flyte.Resources(
        cpu="8",
        memory="32Gi",
        gpu="L40s:4",  # Match the number of GPUs used for sharding
    ),
    extra_args=[
        "--tensor-parallel-size", "4",  # Match sharding config
    ],
)

if __name__ == "__main__":
    # Prefetch with sharding
    run = flyte.prefetch.hf_model(
        repo="meta-llama/Llama-2-70b-hf",
        accelerator="L40s:4",
        shard_config=ShardConfig(
            engine="vllm",
            args=VLLMShardArgs(tensor_parallel_size=4),
        ),
    )
    run.wait()
    flyte.serve(
        vllm_app.clone_with(
            name=vllm_app.name,
            # override the model path to use the prefetched model
            model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
            # set the hf_model_path to None
            hf_model_path=None,
            # stream the model from flyte object store directly to the GPU
            stream_model=True,
        )
    )
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
## CLI options
Complete CLI usage:
```bash
flyte prefetch hf-model <repo> \
  --artifact-name <name> \
  --architecture <architecture> \
  --task <task> \
  --modality text \
  --format safetensors \
  --model-type transformer \
  --short-description "Description" \
  --force 0 \
  --wait \
  --hf-token-key HF_TOKEN \
  --cpu 4 \
  --mem 16Gi \
  --ephemeral-storage 100Gi \
  --accelerator L40s:4 \
  --shard-config shard_config.yaml
```
## Complete example
Here's a complete example of prefetching and using a model:
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "flyteplugins-vllm>=2.0.0b49",
# ]
# ///
import flyte
from flyteplugins.vllm import VLLMAppEnvironment

# define the app environment
vllm_app = VLLMAppEnvironment(
    name="qwen-serving-app",
    # this will download the model from HuggingFace into the app container's filesystem
    model_hf_path="Qwen/Qwen3-0.6B",
    model_id="qwen3-0.6b",
    resources=flyte.Resources(
        cpu="4",
        memory="16Gi",
        gpu="L40s:1",
        disk="10Gi",
    ),
    scaling=flyte.app.Scaling(
        replicas=(0, 1),
        scaledown_after=600,
    ),
    requires_auth=False,
)

if __name__ == "__main__":
    # prefetch the model
    print("Prefetching model...")
    run = flyte.prefetch.hf_model(
        repo="Qwen/Qwen3-0.6B",
        artifact_name="qwen-0.6b",
        cpu="4",
        mem="16Gi",
        ephemeral_storage="50Gi",
    )
    # wait for completion
    print("Waiting for prefetch to complete...")
    run.wait()
    print(f"Model prefetched: {run.outputs()[0].path}")
    # deploy the app
    print("Deploying app...")
    flyte.init_from_config()
    app = flyte.serve(
        vllm_app.clone_with(
            name=vllm_app.name,
            model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
            hf_model_path=None,
            stream_model=True,
        )
    )
    print(f"App deployed: {app.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/serve-and-deploy-apps/prefetch_examples.py*
## Best practices
1. **Prefetch before deployment**: Prefetch models before deploying apps for faster startup
2. **Version models**: Use meaningful artifact names to easily identify the model in object store paths
3. **Shard appropriately**: Shard models for the GPU configuration you'll use for inference
4. **Reuse prefetched models**: Once prefetched, models are cached in your storage backend, so repeated deployments can serve the same artifact without re-downloading from HuggingFace
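For practice 2, one way to keep artifact names meaningful is to derive them mechanically from the repo id and a revision tag. This helper is purely illustrative, not part of the Flyte SDK:

```python
def artifact_name_for(repo: str, revision: str = "main") -> str:
    """Derive a readable, versioned artifact name from a HuggingFace repo id,
    e.g. "Qwen/Qwen3-0.6B" -> "qwen3-0-6b-main"."""
    model = repo.split("/")[-1].lower().replace(".", "-")  # drop org, normalize dots
    return f"{model}-{revision}"

print(artifact_name_for("Qwen/Qwen3-0.6B"))  # "qwen3-0-6b-main"
```

Pass the result as `artifact_name` to `flyte.prefetch.hf_model` so the model's storage path identifies both the model and its revision.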
## Troubleshooting
**Prefetch fails:**
- Check HuggingFace token (if required)
- Verify model repo exists and is accessible
- Check resource availability
- Review prefetch task logs
**Sharding fails:**
- Ensure accelerator matches shard config
- Check GPU memory is sufficient
- Verify `tensor_parallel_size` matches GPU count
- Review prefetch task logs for sharding-related errors
**Model not found in app:**
- Verify the `RunOutput` references the correct run name
- Check that the prefetch completed successfully
- Ensure `model_path` is set correctly
- Review app startup logs
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-agent ===
# Build an agent
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
This section covers how to build, deploy, and run agentic AI applications on Flyte. You'll learn how to implement common agent patterns like ReAct and Plan-and-Execute and deploy agents as hosted services.
## Quickstart
Here's how Flyte maps to the agentic world:
- **`TaskEnvironment`**: The sandboxed execution environment for your agent steps. It configures the container image, hardware resources (CPU, GPU), and secrets (API keys). Think of it as defining "where this code runs."
- **`@env.task`**: Turns any Python function into a remotely-executed step. Each task runs in its own container with the resources you specified. This is the equivalent of a node in LangGraph or n8n.
- **Tasks calling tasks**: A task can `await` other tasks, and each called task gets its own container automatically. No separate workflow decorator is needed: the calling task *is* your workflow, and this is how you build multi-step agentic pipelines.
- **`@flyte.trace`**: Marks helper functions inside a task for fine-grained observability and caching. Each traced call appears as a span in the Flyte dashboard, with its inputs and outputs captured and checkpointed. Use this on your LLM calls, tool executions, and routing decisions to get full visibility into every turn of the agent loop.
> [!TIP]
> See the **Quickstart** for a hands-on walkthrough.
## Next steps
- **Build an agent > Deploy an agent as a service**: Host a FastAPI app, webhook pattern, model serving
- **Build an agent > Building agentic workflows on Flyte**: ReAct pattern, Plan-and-Execute with fan-out, LangGraph integration, and more patterns
## Subpages
- **Build an agent > Building agentic workflows on Flyte**
- **Build an agent > Deploy an agent as a service**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-agent/building-agents ===
# Building agentic workflows on Flyte
Flyte is framework-agnostic: use any Python LLM library (OpenAI SDK, Anthropic SDK, LangChain, LiteLLM, etc.) inside your tasks. The platform provides the production infrastructure layer: sandboxed execution, parallel fan-out, durable checkpointing, and observability for every step of the agent loop.
Two decorators are all you need:
| Decorator | What it does | Think of it as... |
|-----------|-------------|-------------------|
| **`@env.task`** | Runs a function in its own container on Flyte with dedicated resources, dependencies, and secrets | A sandboxed agent step with its own execution environment |
| **`@flyte.trace`** | Marks a helper function for observability, where each call appears as a span in the Flyte dashboard with captured I/O | An observability hook on your LLM calls, tool executions, and routing decisions |
## ReAct pattern: Reason, Act, Observe (no framework needed)
The [ReAct pattern](https://arxiv.org/abs/2210.03629) is the most common agent architecture: the LLM reasons about what to do, calls a tool, observes the result, and repeats until done. This example implements the pattern directly with Flyte:
```
Thought → Action → Observation → repeat until done
```
```python
# agent.py
import json
from pydantic import BaseModel
import flyte
from openai import AsyncOpenAI

env = flyte.TaskEnvironment(
    name="agent_env",
    image=flyte.Image.from_debian_base(python_version=(3, 13)).with_pip_packages("openai"),
    resources=flyte.Resources(cpu=2, memory="2Gi"),
    secrets=[flyte.Secret(key="OPENAI_API_KEY")],
)

TOOLS = {"add": lambda a, b: a + b, "multiply": lambda a, b: a * b}


@flyte.trace  # each call = a span in Flyte dashboard
async def reason(goal: str, history: str) -> dict:
    """LLM picks a tool or returns a final answer."""
    r = await AsyncOpenAI().chat.completions.create(
        model="gpt-4.1-nano",
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content":
                f"Tools: {list(TOOLS)}. Respond JSON: "
                '{"thought":..,"tool":..,"args":{}} or {"thought":..,"done":true,"answer":..}'},
            {"role": "user", "content": f"Goal: {goal}\n\n{history}\nWhat next?"},
        ],
    )
    return json.loads(r.choices[0].message.content)


@flyte.trace
async def act(tool: str, args: dict) -> str:
    """Execute the chosen tool."""
    return str(TOOLS[tool](**args))


class AgentResult(BaseModel):
    answer: str
    steps: int


@env.task  # runs in its own sandboxed container
async def react_agent(goal: str, max_steps: int = 10) -> AgentResult:
    history = ""
    for step in range(1, max_steps + 1):  # the agent loop
        decision = await reason(goal, history)  # Thought
        if decision.get("done"):
            return AgentResult(answer=str(decision["answer"]), steps=step)
        result = await act(decision["tool"], decision["args"])  # Action
        history += f"Step {step}: {decision['thought']} -> {decision['tool']}({decision['args']}) = {result}\n"  # Observation
    return AgentResult(answer="Max steps reached", steps=max_steps)
```
```bash
flyte run agent.py react_agent --goal "What is (12 + 8) * 3?"
# => AgentResult(answer='60', steps=3)
```
**What's happening under the hood:**
- `react_agent` runs in a sandboxed container with only `openai` installed and 2 CPU / 2GB RAM
- Each `reason()` and `act()` call is traced, so you see every LLM call, every tool invocation, and every intermediate result in the Flyte dashboard
- The agent's inputs and final output are durably persisted, letting you inspect any past run end-to-end
- Swap in your own tools (web search, database queries, API calls) by adding to the `TOOLS` dict
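A new tool is just another entry in the dict. As a hypothetical sketch (the `sqrt` tool is illustrative, not part of the example above):

```python
import math

# Extending the TOOLS dict from the example above with a hypothetical tool.
TOOLS = {
    "add": lambda a, b: a + b,
    "multiply": lambda a, b: a * b,
    "sqrt": lambda x: math.sqrt(x),  # any Python callable works as a tool
}

# The agent's `act()` step dispatches by name, unchanged:
result = str(TOOLS["sqrt"](x=16))
```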
> [!TIP]
> See the [Agentic Refinement docs](../advanced-project/agentic-refinement), [Traces docs](../task-programming/traces), and [more patterns (planner, debate, etc.)](https://github.com/unionai/workshops/tree/main/tutorials/multi-agent-workflows).
## Plan-and-Execute with parallel fan-out (LangGraph on Flyte)
The [Plan-and-Execute pattern](https://blog.langchain.com/plan-and-execute-agents/) splits a complex query into sub-tasks, fans them out in parallel, then synthesizes the results. This example runs a LangGraph research agent with web search tool calling, and Flyte handles the parallelization, giving each sub-task its own container.
Here's `graph.py`, a LangGraph agent with tool calling (search the web, then summarize):
```python
import flyte
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage
from langgraph.graph import StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_community.tools.tavily_search import TavilySearchResults
def build_research_graph(openai_key: str, tavily_key: str):
tools = [TavilySearchResults(max_results=2, tavily_api_key=tavily_key)]
llm = ChatOpenAI(model="gpt-4.1-nano", api_key=openai_key).bind_tools(tools)
@flyte.trace
async def agent(state: MessagesState):
msgs = [SystemMessage(content="Research the topic. Use search, then summarize.")] + state["messages"]
return {"messages": [await llm.ainvoke(msgs)]}
@flyte.trace
async def route(state: MessagesState):
last = state["messages"][-1]
return "tools" if getattr(last, "tool_calls", None) else "__end__"
g = StateGraph(MessagesState)
g.add_node("agent", agent)
g.add_node("tools", ToolNode(tools))
g.set_entry_point("agent")
g.add_conditional_edges("agent", route, {"tools": "tools", "__end__": "__end__"})
g.add_edge("tools", "agent")
return g.compile()
```
And `workflow.py`, which plans topics, fans out research in parallel, and synthesizes:
```python
import os, json, asyncio, flyte
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from graph import build_research_graph
env = flyte.TaskEnvironment(
name="research_env",
image=flyte.Image.from_debian_base(python_version=(3, 13))
.with_pip_packages("openai", "langchain-openai", "langchain-community", "langgraph", "tavily-python"),
resources=flyte.Resources(cpu=2, memory="2Gi"),
secrets=[flyte.Secret(key="OPENAI_API_KEY"), flyte.Secret(key="TAVILY_API_KEY")],
)
@env.task
async def plan(query: str, n: int = 3) -> list[str]:
"""Split query into sub-topics."""
r = await ChatOpenAI(model="gpt-4.1-nano", api_key=os.environ["OPENAI_API_KEY"]).ainvoke(
f"Break into exactly {n} sub-topics. Return ONLY a JSON array of strings, e.g. [\"topic1\", \"topic2\"]. No objects.\n\n{query}")
topics = json.loads(r.content)[:n]
return [t if isinstance(t, str) else str(t.get("sub_topic", t)) for t in topics]
@env.task
async def research(topic: str) -> str:
"""Run LangGraph agent on one topic (each call = separate container)."""
graph = build_research_graph(os.environ["OPENAI_API_KEY"], os.environ["TAVILY_API_KEY"])
result = await graph.ainvoke({"messages": [HumanMessage(content=f"Research: {topic}")]})
return json.dumps({"topic": topic, "report": result["messages"][-1].content})
@env.task
async def synthesize(query: str, reports: list[str]) -> str:
"""Combine sub-reports into a final summary."""
parsed = [json.loads(r) for r in reports]
sections = "\n\n".join(f"## {r['topic']}\n{r['report']}" for r in parsed)
r = await ChatOpenAI(model="gpt-4.1-nano", api_key=os.environ["OPENAI_API_KEY"]).ainvoke(
f"Synthesize reports on: {query}\n\n{sections}\n\nKey takeaways:")
return r.content
@env.task
async def research_workflow(query: str, num_topics: int = 3) -> str:
topics = await plan(query, num_topics)
reports = list(await asyncio.gather(*[research(t) for t in topics])) # parallel fan-out
return await synthesize(query, reports)
```
```bash
flyte run workflow.py research_workflow --query "Impact of storms on travel insurance payouts"
```
**What's happening under the hood:**
```
research_workflow (orchestrator)
├── plan          → LLM breaks query into N sub-topics              [container 1]
├── research(t1)  → LangGraph agent loop with web search tools      [container 2] ┐
├── research(t2)  → LangGraph agent loop with web search tools      [container 3] │ parallel
├── research(t3)  → LangGraph agent loop with web search tools      [container 4] ┘
└── synthesize    → LLM combines reports into final answer          [container 5]
```
- **Fan-out:** `asyncio.gather()` launches all research tasks in parallel, each in its own sandboxed container
- **Tool calling inside each research task:** The LangGraph agent calls Tavily web search, observes results, reasons about them, and loops until it has enough information (the inner agentic loop)
- **Observability:** `@flyte.trace` on the LangGraph nodes means every LLM call, every tool call, and every routing decision is visible as a span in the Flyte dashboard
- **Durable checkpointing:** Each task's output is persisted. If `synthesize` fails, re-running skips the completed `plan` and `research` steps (with caching enabled)
## More agentic patterns
Flyte is framework-agnostic, so these patterns work with any LLM library. Each maps to well-known agent architectures:
| Pattern | What it does | When to use it | Link |
|---------|-------------|----------------|------|
| **ReAct** | Reason → Act → Observe loop with tool calling | Single-agent tasks with tools (API calls, search, code execution) | [multi-agent-workflows/react](https://github.com/unionai/workshops/tree/main/tutorials/multi-agent-workflows) |
| **Plan-and-Execute** | LLM creates a plan, independent steps fan out in parallel, results are synthesized | Complex queries that decompose into parallel sub-tasks | [multi-agent-workflows/planner](https://github.com/unionai/workshops/tree/main/tutorials/multi-agent-workflows) |
| **Evaluator-Optimizer (Reflection)** | Generate β Critique β Refine loop until quality threshold met | Content generation, code generation, any task with clear quality criteria | [Agentic Refinement docs](../advanced-project/agentic-refinement) |
| **Orchestrator-Workers (Manager)** | Supervisor agent delegates to specialist worker agents, reviews quality, requests revisions | Multi-agent systems where sub-tasks require different expertise | [multi-agent-workflows/manager](https://github.com/unionai/workshops/tree/main/tutorials/multi-agent-workflows) |
| **Debate** | Multiple agents solve independently, then debate to consensus | High-stakes decisions where diverse reasoning improves accuracy | [multi-agent-workflows/debate](https://github.com/unionai/workshops/tree/main/tutorials/multi-agent-workflows) |
| **Sequential (Prompt Chaining)** | Static pipeline of LLM calls, no dynamic routing | Predictable multi-step transformations (extract → validate → format) | [multi-agent-workflows/sequential](https://github.com/unionai/workshops/tree/main/tutorials/multi-agent-workflows) |
## How Flyte's primitives map to the agent stack
If you're coming from LangGraph, CrewAI, OpenAI Agents SDK, or similar frameworks, here's how the concepts you already know translate:
**Your agent loop** is a Python `for`/`while` loop inside an `@env.task`. Each iteration calls `@flyte.trace`-decorated functions for reasoning and tool execution. Flyte doesn't impose a loop structure; you write it in plain Python, which means any pattern (ReAct, reflection, plan-and-execute) works naturally.
**Tool calling** is just calling Python functions. Define your tools as regular functions, decorate them with `@flyte.trace` for observability, and call them from within the agent loop. Use any tool-calling mechanism your LLM SDK provides (OpenAI function calling, Anthropic tool use, LangChain `bind_tools()`). MCP servers can be accessed from within tasks using the MCP Python SDK.
**Parallel fan-out** (LangGraph's `Send()`, n8n's Split in Batches) is `asyncio.gather()`. Each awaited task gets its own container, giving you true parallelism on separate hardware, not just concurrent coroutines.
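The fan-out shape can be sketched with plain `asyncio` (on Flyte, each awaited task call would run in its own container rather than as a local coroutine):

```python
import asyncio

async def research(topic: str) -> str:
    await asyncio.sleep(0)  # stand-in for real work (LLM calls, web search)
    return f"report on {topic}"

async def workflow(topics: list[str]) -> list[str]:
    # launch all sub-tasks concurrently; gather preserves input order
    return list(await asyncio.gather(*[research(t) for t in topics]))

reports = asyncio.run(workflow(["storms", "payouts", "claims"]))
```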
**State and checkpointing** (LangGraph's Checkpointers, Threads) is automatic. Every task's inputs and outputs are durably persisted. `@flyte.trace` adds sub-step checkpoints within a task. Re-running with caching enabled skips completed steps, Flyte's equivalent of replaying from a checkpoint.
**Routing and conditional logic** (LangGraph's `add_conditional_edges`, n8n's If/Switch nodes) is Python `if/else`. No special API needed.
**Environment isolation** (different dependencies per step) is `TaskEnvironment`. Your LLM step can use `langchain==0.3`; your data step can use `pandas` + GPU. Each gets its own container image.
**Guardrails and validation** are Python code between steps: `if/else` checks, Pydantic validation, structured output parsing, or libraries like NeMo Guardrails. Raise an exception to fail a step and trigger retries.
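As a minimal sketch, a guardrail between steps can be plain Python that validates the model's structured output before acting on it (the JSON shape below matches the ReAct example earlier on this page):

```python
import json

def validate_decision(raw: str) -> dict:
    """Parse and sanity-check an LLM decision before acting on it."""
    decision = json.loads(raw)  # fails fast on malformed JSON
    if decision.get("done"):
        if "answer" not in decision:
            raise ValueError("final decision is missing 'answer'")
    elif "tool" not in decision or not isinstance(decision.get("args"), dict):
        raise ValueError("tool decision needs 'tool' and an 'args' dict")
    return decision

decision = validate_decision(
    '{"thought": "add first", "tool": "add", "args": {"a": 1, "b": 2}}'
)
```

Raising here fails the step, which triggers Flyte's normal retry behavior.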
**Observability:** The Flyte dashboard shows the full execution tree with per-step inputs, outputs, logs, resource usage, and timing. `@flyte.trace` adds spans within a task for fine-grained visibility into individual LLM calls and tool invocations. For LLM-specific metrics (token usage, cost per call), integrate with Langfuse or LangSmith from within your tasks.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-agent/deploy-agent-as-service ===
# Deploy an agent as a service
Flyte makes it straightforward to deploy internal apps (chatbots, dashboards, API endpoints) behind a URL, with no separate infrastructure. This is how you turn an agent into a hosted service that your team (or other agents) can call.
## Chat agent with Gradio
This example takes the ReAct agent from [Building agentic workflows](./building-agents) and wraps it in a Gradio chat interface, deployed as a Flyte app. Users interact in the browser, and each reasoning step streams back in real time.
```python
# app.py
import json
import gradio as gr
import flyte
from flyte.app import AppEnvironment
from openai import AsyncOpenAI
# --- ReAct agent (same pattern as the ReAct agent in Building agentic workflows on Flyte) ---
TOOLS = {"add": lambda a, b: a + b, "multiply": lambda a, b: a * b}
async def reason(goal: str, history: str) -> dict:
"""LLM picks a tool or returns a final answer."""
r = await AsyncOpenAI().chat.completions.create(
model="gpt-4.1-nano",
response_format={"type": "json_object"},
messages=[
{"role": "system", "content":
f"Tools: {list(TOOLS)}. Respond JSON: "
'{"thought":..,"tool":..,"args":{}} or '
'{"thought":..,"done":true,"answer":..}'},
{"role": "user", "content": f"Goal: {goal}\n\n{history}\nWhat next?"},
],
)
return json.loads(r.choices[0].message.content)
async def act(tool: str, args: dict) -> str:
"""Execute the chosen tool."""
return str(TOOLS[tool](**args))
async def react_agent(message: str, history: list):
"""ReAct loop that streams intermediate steps, then the final answer."""
output, trace = "", ""
for step in range(1, 11):
decision = await reason(message, trace)
if decision.get("done"):
yield output + f"\n\n**Answer:** {decision['answer']}"
return
result = await act(decision["tool"], decision["args"])
trace += (
f"Step {step}: {decision['thought']} "
f"-> {decision['tool']}({decision['args']}) = {result}\n"
)
output += (
f"**Step {step}:** {decision['thought']}\n"
f"`{decision['tool']}({decision['args']})` -> `{result}`\n\n"
)
yield output
yield output + "\n\nMax steps reached."
# --- Deploy as a Flyte app ---
serving_env = AppEnvironment(
name="react-agent-chat",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"gradio", "openai",
),
secrets=[flyte.Secret(key="OPENAI_API_KEY")],
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
port=7860,
)
@serving_env.server
def server():
gr.ChatInterface(
react_agent,
title="ReAct Agent",
examples=["What is (12 + 8) * 3?", "Add 99 and 1, then multiply by 5"],
).launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
flyte.init_from_config()
flyte.serve(serving_env)
```

```bash
# Local development
python app.py
# Deploy to Flyte
flyte deploy app.py serving_env
```
Flyte assigns a URL, handles TLS, and auto-scales the app.
**What's happening under the hood:**
- `AppEnvironment` defines the container image, secrets, resources, and port for the app
- `@serving_env.server` marks the function that Flyte calls on remote deployment
- `gr.ChatInterface` with an async generator gives streaming output: users see each reasoning step appear as the agent works
- `requires_auth=False` makes the app publicly accessible; set to `True` to require Flyte authentication
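The streaming behavior comes from the async-generator contract: each `yield` replaces the currently displayed message with a longer string. A framework-free sketch of that shape (names here are illustrative):

```python
import asyncio

async def stream_steps(message: str, history: list):
    """Minimal async generator in the shape a streaming chat UI consumes:
    each yield supersedes the previous one with a longer string."""
    output = ""
    for step in (f"thinking about {message!r}", "calling tool", "done"):
        output += f"**Step:** {step}\n"
        yield output

async def collect():
    return [chunk async for chunk in stream_steps("demo", [])]

chunks = asyncio.run(collect())
```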
## Other deployment patterns
**FastAPI endpoint:** For API-first agents, use `FastAPIAppEnvironment` to expose your agent behind a REST endpoint that other services or agents can call programmatically.
**Webhook-triggered workflows:** [Deploy a FastAPI app](../build-apps/fastapi-app) that receives webhooks and calls `flyte.run()` on a [remote task](../task-programming/remote-tasks) to kick off longer agentic workflows as background tasks.
**Model serving:** [Serve open-weight LLMs](../build-apps/vllm-app) on GPUs behind an OpenAI-compatible API with `VLLMAppEnvironment` or `SGLangAppEnvironment`.
> [!TIP]
> See [Build Apps](../build-apps/_index), [App usage patterns](../build-apps/app-usage-patterns), and [Configure Apps](../configure-apps/_index) for more details.
> For a hands-on example with a research agent Gradio UI, see [workshops/starter-examples/flyte-local-dev](https://github.com/unionai/workshops/tree/main/tutorials/starter-examples/flyte-local-dev).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/sandboxing ===
# Sandboxing
> [!NOTE]
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
A **sandbox** is an isolated, secure environment where code can run without affecting the host system.
Sandboxes restrict what the executing code can do (limiting filesystem access, blocking network calls, and preventing arbitrary system operations) so that even malicious or buggy code cannot cause harm.
The exact restrictions depend on the sandboxing approach: some sandboxes eliminate dangerous operations entirely, while others provide full capabilities within an isolated, disposable container.
## Why sandboxing matters for AI
LLM-generated code is inherently untrusted.
The model may produce code that is correct and useful, but it can also produce code that is dangerous, and it does so without intent or awareness.
| Risk | Example |
|------|---------|
| Data destruction | `DELETE FROM orders WHERE 1=1` → wipes an entire table |
| Credential exfiltration | Reads environment variables and sends API keys to an external endpoint |
| Infinite loops | `while True: pass` → consumes CPU indefinitely |
| Resource abuse | Spawns thousands of threads or allocates unbounded memory |
| Filesystem damage | `rm -rf /` or overwrites critical configuration files |
| Network abuse | Makes unauthorized API calls, sends spam, or joins a botnet |
Running LLM-generated code without a sandbox means trusting the model to never make these mistakes.
Sandboxing eliminates this trust requirement by making dangerous operations structurally impossible.
## Types of sandboxes
There are three broad approaches to sandboxing LLM-generated code, each with different tradeoffs:
| Type | How it works | Tradeoffs | Examples |
|------|-------------|-----------|----------|
| **One-shot execution** | Code runs to completion in a disposable container, then the container is discarded. Stdout, stderr, and outputs are captured. | Simple, no state reuse. Good for single-turn tasks. | Container tasks, serverless functions |
| **Interactive sessions** | A persistent VM or container where you send commands incrementally and observe results between steps. Sessions last for the lifetime of the VM. | Flexible and multi-turn, but heavier to provision and manage. | E2B, Daytona, fly.io |
| **Programmatic tool calling** | The LLM generates orchestration code that calls a predefined set of tools. The orchestration code runs in a sandbox while the tools run in full containers. | Durable, observable, and secure. Tools are known ahead of time. | Flyte workflow sandboxing |
## What Flyte offers
Flyte provides two complementary sandboxing approaches:
### Workflow sandbox (Monty)
A **sandboxed orchestrator** built on [Monty](https://github.com/pydantic/pydantic-monty), a Rust-based sandboxed Python interpreter.
The sandbox starts in microseconds, runs pure Python control flow, and dispatches heavy work to full container tasks through the Flyte controller.
This enables the **programmatic tool calling** pattern (also known as code mode): LLMs generate Python orchestration code that invokes registered tools, and Flyte executes it safely with full durability, observability, and type checking.
### Code sandbox (container)
A **stateless code sandbox** that runs arbitrary Python scripts or shell commands inside an ephemeral Docker container.
The container is built on demand from declared dependencies, executed once, and discarded.
This is the right choice when you need full Python capabilities: third-party packages, file I/O, shell commands, or any computation that goes beyond pure control flow.
### When to use which
| | Workflow sandbox | Code sandbox |
|---|---|---|
| **Runtime** | Monty (Rust-based Python interpreter) | Ephemeral Docker container |
| **Startup** | Microseconds | Seconds (image build + container spin-up) |
| **Capabilities** | Pure Python control flow only (no imports, no I/O, no network) | Full Python environment: any package, any library, full I/O |
| **Use case** | LLM-generated orchestration logic that calls registered tools | Arbitrary computation: data processing, test execution, ETL, shell pipelines |
| **State** | Runs within a worker container process | Stateless: fresh container per invocation |
| **Security model** | Dangerous operations are structurally impossible | Isolated container |
- Use the **workflow sandbox** when you need to run untrusted control flow (loops, conditionals, routing) that dispatches work to known tasks. It starts in microseconds and provides the strongest isolation guarantees.
- Use the **code sandbox** when you need full Python capabilities: third-party packages, file I/O, shell commands, or any computation that goes beyond pure control flow.
### Learn more
- **Sandboxing > Workflow sandboxing in Flyte**: How the Monty-based sandboxed orchestrator works, with examples
- **Sandboxing > Programmatic tool calling for agents**: The concept behind programmatic tool calling and how to build agents that use it
- **Sandboxing > Code sandboxing**: Running arbitrary code and commands in ephemeral containers with `flyte.sandbox.create()`
## Subpages
- **Sandboxing > Workflow sandboxing in Flyte**
- **Sandboxing > Programmatic tool calling for agents**
- **Sandboxing > Code sandboxing**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/sandboxing/workflow-sandboxing-flyte ===
# Workflow sandboxing in Flyte
Flyte provides a sandboxed orchestrator that lets you run pure Python control flow in a secure sandbox while dispatching heavy work to full container tasks.
This enables patterns where LLMs generate orchestration code dynamically, and Flyte executes it safely with full durability and observability.
## Why workflow sandboxing?
Three properties of Flyte make it a natural fit for sandboxed code execution:
1. **Infrastructure on demand**: Flyte spins up containers with specific permissions, secrets, and resources for each task.
2. **LLMs are great at Python**: Models trained on billions of lines of code can reliably generate Python orchestration logic.
3. **Microsecond startup**: The sandbox is powered by [Monty](https://github.com/pydantic/pydantic-monty) (Pydantic's Rust-based Python interpreter), which starts in microseconds without the overhead of VMs or containers.
The result: LLMs generate the orchestration code (control flow, conditionals, loops), and Flyte tasks handle the heavy lifting (data access, computation, external APIs) in full containers.
## How it works
Your generated code runs inside one or more **Monty sandboxes**: lightweight Python interpreters embedded within a **worker container**. Each sandbox can execute pure Python (variables, loops, conditionals, function calls) but has no access to the filesystem, network, imports, or OS. A **bridge layer** acts as a hypervisor between the worker container and the sandboxes, handling opaque IO and routing callable tasks. When your code calls an external task, the bridge dispatches it, either as a method in the outer Python process or as a durable remote call through the Flyte controller (via the Queue Service):
```mermaid
flowchart TB
subgraph worker["Worker Container"]
subgraph bridge["Bridge / Hypervisor"]
IO["Opaque IO: File, Dir, DataFrame"]
subgraph sandbox1["Monty Sandbox 1"]
A1["Your code: loops, variables, conditionals"]
B1["result = add(x, y)"]
end
subgraph sandbox2["Monty Sandbox 2"]
A2["More sandboxed code"]
end
end
end
A1 --> B1
B1 -- "callable task" --> bridge
bridge -- "result" --> B1
IO -. "routed to tasks" .-> bridge
bridge -- "external call" --> QS["Queue Service"]
QS -- "completion" --> bridge
```
Each sandbox sees external tasks as opaque function calls. When your code hits one, Monty **pauses**, and the bridge layer dispatches the task, either directly in the outer Python process or as a remote durable call through the Flyte controller system (Queue Service). Once the call completes, Monty **resumes** with the result. Your code never knows the difference; it just looks like a regular function call that returns a value. Multiple Monty sandboxes can run within the same worker container, each isolated like a lightweight VM.
**Opaque IO types** like `File`, `Dir`, and `DataFrame` are managed by the bridge layer and pass through the sandbox without inspection. Your code can route them between tasks but cannot read or modify their contents.
## Example: sandboxed orchestrator
Use `@env.sandbox.orchestrator` to define a sandboxed task that calls regular worker tasks.
The orchestrator contains only pure Python control flow; all heavy computation runs in worker containers.
```python
import flyte
env = flyte.TaskEnvironment(name="sandboxed-demo")
# Worker tasks β run in their own containers
@env.task
def add(x: int, y: int) -> int:
return x + y
@env.task
def multiply(x: int, y: int) -> int:
return x * y
@env.task
def fib(n: int) -> int:
"""Compute the nth Fibonacci number iteratively."""
a, b = 0, 1
for _ in range(n):
a, b = b, a + b
return a
# Sandboxed orchestrator β pure Python control flow
@env.sandbox.orchestrator
def pipeline(n: int) -> dict[str, int]:
fib_result = fib(n)
linear_result = add(multiply(n, 2), 5)
total = add(fib_result, linear_result)
return {
"fib": fib_result,
"linear": linear_result,
"total": total,
}
```
When `pipeline` runs, Monty executes the control flow in the sandbox. Each call to `fib`, `multiply`, and `add` pauses the sandbox, runs the worker task in a container, and resumes with the result.
Both `def` and `async def` orchestrators are supported; Monty natively handles `await` expressions.
## Example: dynamic code execution
For cases where the code itself is generated at runtime (from templates, user input, or LLM output), use `orchestrator_from_str()` and `orchestrate_local()`.
### Reusable task from a code string
`orchestrator_from_str()` creates a reusable task template from a Python code string.
The value of the **last expression** becomes the return value.
```python
import flyte
import flyte.sandbox
env = flyte.TaskEnvironment(name="code-string-demo")
@env.task
def add(x: int, y: int) -> int:
return x + y
@env.task
def multiply(x: int, y: int) -> int:
return x * y
# Create a reusable task from a code string
compute_pipeline = flyte.sandbox.orchestrator_from_str(
"""
partial = add(x, y)
multiply(partial, scale)
""",
inputs={"x": int, "y": int, "scale": int},
output=int,
tasks=[add, multiply],
name="compute-pipeline",
)
# flyte.run(compute_pipeline, x=2, y=3, scale=4) → 20
```
### One-shot local execution
`orchestrate_local()` executes a code string and returns the result directly: no task template, no controller.
Use it for quick one-off computations.
```python
result = await flyte.sandbox.orchestrate_local(
"add(x, y) * 2",
inputs={"x": 1, "y": 2},
tasks=[add],
)
# result → 6
```
### Parameterized code generation
Because the code is a string, you can generate it programmatically:
```python
def make_reducer(operation: str) -> flyte.sandbox.CodeTaskTemplate:
"""Create a sandboxed task that reduces a list using the given operation."""
if operation == "sum":
body = """
acc = 0
for v in values:
acc = acc + v
acc
"""
elif operation == "product":
body = """
acc = 1
for v in values:
acc = acc * v
acc
"""
else:
raise ValueError(f"Unknown operation: {operation}")
return flyte.sandbox.orchestrator_from_str(
body,
inputs={"values": list},
output=int,
name=f"reduce-{operation}",
)
sum_task = make_reducer("sum")
product_task = make_reducer("product")
```
## Building agents with programmatic tool calling
The sandboxed orchestrator and `orchestrate_local()` are the foundation for building agents that use **programmatic tool calling**: systems where an LLM generates Python orchestration code, and the sandbox executes it with registered tools.
Because `orchestrate_local()` accepts a plain code string and a list of tool functions, you can wire it into an LLM generate-execute-retry loop: the model writes code, the sandbox runs it, and on failure the error feeds back to the model for correction.
See [Programmatic tool calling for agents](./code-mode) for the full concept, agent implementation patterns, and end-to-end examples.
## Syntax restrictions
Monty enforces strict syntax restrictions to guarantee sandbox safety.
These restrictions are a feature, not a limitation: they ensure that sandboxed code is deterministic and side-effect free.
### Allowed
| Feature | Notes |
|---------|-------|
| Variables and assignment | `x = 1` |
| Arithmetic and comparisons | `x + y`, `x > y` |
| String operations | Concatenation, formatting |
| `if`/`elif`/`else` | Conditional logic |
| `for` loops | Iteration over lists, ranges, dicts |
| `while` loops | Condition-based loops |
| Function definitions (`def`) | Local helper functions |
| `async def` and `await` | Async orchestrators |
| List/dict/tuple literals | `[1, 2, 3]`, `{"key": "value"}` |
| List comprehensions | `[x * 2 for x in items]` |
| `.append()` on lists | Building lists incrementally |
| Subscript reading | `x = d["key"]`, `x = l[0]` |
| External task calls | Calling registered `@env.task` workers |
| `raise` | Raising exceptions |
### Not allowed
| Feature | Workaround |
|---------|------------|
| `import` | All available functions are provided directly |
| Subscript assignment (`d[k] = v`, `l[i] = v`) | Build dicts as literals; use `.append()` for lists |
| Augmented assignment (`x += 1`) | Use `x = x + 1` |
| `class` definitions | Use dicts or tuples |
| `with` statements | Not needed; no resource management in sandbox |
| `try`/`except` | Errors propagate to the controller |
| Walrus operator (`:=`) | Use separate assignment |
| `yield`/`yield from` | Not supported |
| `global`/`nonlocal` | Not supported |
| Set literals/comprehensions | Use lists |
| `del` statements | Not supported |
| `assert` statements | Use `if` + `raise` |
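The workarounds in the table above can be sketched in plain Python (these shapes run in regular Python too, which is what makes them convenient to test before sending to the sandbox):

```python
# augmented assignment (x += 1) -> explicit assignment
x = 0
x = x + 1

# subscript assignment (d[k] = v) -> build the dict as a literal
d = {"count": x}

# list building via .append() instead of l[i] = v
items = []
for v in [1, 2, 3]:
    items.append(v * 2)

# assert -> if + raise
if d["count"] != 1:
    raise ValueError("unexpected count")
```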
### Type restrictions
- **Primitive types**: `int`, `float`, `str`, `bool`, `bytes`, `None`
- **Collection types**: `list`, `dict`, `tuple` (including generic forms like `list[int]`, `dict[str, float]`)
- **Opaque IO handles**: `File`, `Dir`, `DataFrame` (pass-through only; cannot be inspected in the sandbox)
- **Union types**: `Optional[T]` and `Union` of allowed types
- **Not allowed**: Custom classes, dataclasses, Pydantic models, or any user-defined types
## Security model
The sandboxed orchestrator provides security through restriction, not trust:
- **No filesystem access**: Cannot read, write, or list files
- **No network access**: Cannot make HTTP requests, open sockets, or resolve DNS
- **No OS access**: Cannot spawn processes, read environment variables, or access system resources
- **No imports**: Cannot load any Python modules
- **Opaque IO**: `File`, `Dir`, and `DataFrame` values pass through the sandbox without inspection; the sandbox can route them between tasks but cannot read their contents
- **Type-checked boundaries**: Inputs and outputs are validated against declared types at the sandbox boundary
- **Deterministic execution**: The same inputs always produce the same outputs (excluding external task results)
The sandbox runs untrusted code safely because dangerous operations are not just discouraged; they are structurally impossible in the Monty runtime.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/sandboxing/code-mode ===
# Programmatic tool calling for agents
**Programmatic tool calling** (also known as **code mode**) is a pattern where LLMs write executable code instead of making individual tool calls.
Rather than the model emitting a sequence of JSON tool-call objects and the system routing each one, the model generates a single block of code that calls multiple tools, transforms data, and applies logic, all executed in a sandbox.
The key insight: LLMs are trained on billions of lines of code, but only a small amount of synthetic tool-call data.
Code generation is a more natural and reliable output modality for models than structured tool-call schemas.
## Programmatic tool calling vs sequential tool calling
In sequential tool calling, every intermediate result passes through the model's context window.
The model calls one tool, reads the result, decides what to do next, calls another tool, and so on.
Each round-trip costs tokens and latency.
With programmatic tool calling, the model generates a complete program upfront.
The sandbox executes it, and only the final result returns to the model.
| Aspect | Sequential tool calling | Programmatic tool calling |
|--------|-------------|-----------|
| **Output format** | JSON tool-call objects, one at a time | A single block of executable code |
| **Data flow** | Every intermediate result passes through the model | Intermediate results stay in the sandbox |
| **Context overhead** | Grows with each tool call (all results in context) | Fixed: only tool signatures in context |
| **Multi-step logic** | Model re-invoked at every step | Sandbox executes loops, conditionals, transforms |
| **Scaling with tools** | Context grows linearly with number of tool definitions | Tools discovered progressively or loaded on demand |
## Why programmatic tool calling is powerful
### Token efficiency
Sequential tool calling loads all tool definitions into the context window upfront and passes every intermediate result through the model.
Programmatic tool calling reduces this dramatically:
- **98%+ context reduction** reported by Anthropic when using code execution with MCP servers: from 150,000 tokens down to 2,000 tokens for the same task.
- **99.9% reduction** reported by Cloudflare for large APIs: approximately 1,000 tokens with programmatic tool calling versus 1.17 million tokens when exposing each API endpoint as a separate tool.
### Performance
By eliminating round-trips through the model for intermediate steps, programmatic tool calling achieves significant speed improvements.
The sandbox evaluates conditionals, loops, and data transformations locally, with no "time to first token" delay for each step.
### Natural programming patterns
Code naturally expresses patterns that are awkward or impossible in tool-call sequences:
- **Loops**: Process a list of items without the model deciding "call this tool again" for each one
- **Conditionals**: Branch on intermediate results without another model invocation
- **Data transformation**: Filter, map, and aggregate data before passing it to the next tool
- **Variable reuse**: Store intermediate results and reference them later
### Progressive tool discovery
Instead of loading hundreds of tool definitions into the context window, programmatic tool calling supports progressive discovery.
The model can search for relevant tools, load only what it needs, and compose them in code.
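The discovery step itself can be as simple as a keyword search over the tool registry. A minimal sketch (the `search_tools` helper and the stub tools here are hypothetical, not part of the Flyte API):

```python
def search_tools(registry: dict, query: str) -> dict:
    """Return the subset of tools whose name or docstring mentions a query word."""
    words = query.lower().split()
    return {
        name: fn
        for name, fn in registry.items()
        if any(w in (name + " " + (fn.__doc__ or "")).lower() for w in words)
    }


def fetch_data(dataset: str) -> list:
    """Fetch tabular rows by dataset name."""
    return []


def create_chart(chart_type: str, title: str) -> str:
    """Generate a chart as an HTML snippet."""
    return ""


REGISTRY = {"fetch_data": fetch_data, "create_chart": create_chart}
print(sorted(search_tools(REGISTRY, "chart")))  # ['create_chart']
```

The model can then request full definitions only for the tools that matched, keeping the context window small regardless of registry size.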
### Data privacy
Intermediate results stay in the sandbox execution environment.
They never re-enter the model's context window, which means sensitive data (PII, credentials, financial records) can be processed without the model seeing it.
## Example: sequential vs programmatic tool calling
Consider a task: "Analyze sales data, filter for Q4, calculate statistics, and create a chart."
### Sequential tool calling approach
The model makes serial tool calls, with each result passing through the context window:
```
Step 1: Model -> tool_call: fetch_data("sales_2024")
        Result: [150KB of sales data] -> back into model context
Step 2: Model -> tool_call: filter_data(data, "month", ">=", "Oct")
        Result: [40KB of filtered data] -> back into model context
Step 3: Model -> tool_call: calculate_statistics(filtered, "revenue")
        Result: {"mean": 112000, ...} -> back into model context
Step 4: Model -> tool_call: create_chart("bar", "Q4 Revenue", ...)
        Result: "" -> back into model context
```
Four round-trips through the model.
The 150KB dataset enters the context window and stays there.
### Programmatic tool calling approach
The model generates a single code block:
```python
data = fetch_data("sales_2024")
q4_months = ["Oct", "Nov", "Dec"]
q4_data = [row for row in data if row["month"] in q4_months]
stats = calculate_statistics(q4_data, "revenue")

months = []
revenues = []
for row in q4_data:
    if row["month"] not in months:
        months.append(row["month"])
for month in months:
    total = 0
    for row in q4_data:
        if row["month"] == month:
            total = total + row["revenue"]
    revenues.append(total)

chart = create_chart("bar", "Q4 Revenue by Month", months, revenues)
{"charts": [chart], "summary": "Q4 stats: " + str(stats)}
```
One model invocation.
The data never re-enters the model's context window.
The sandbox handles the filtering, aggregation, and chart creation locally.
## Example: defining tools
Tools are plain Python functions with type annotations and docstrings.
The agent auto-generates its system prompt from these signatures, so adding a tool requires no other changes.
```python
async def fetch_data(dataset: str) -> list:
    """Fetch tabular data by dataset name.

    Available datasets:
    - "sales_2024": columns month, region, revenue, units
    - "employees": columns name, department, salary, years_exp, performance_rating
    - "website_traffic": columns date, page, visitors, bounce_rate, avg_duration
    - "inventory": columns product, category, stock, price, supplier
    """
    ...


async def create_chart(chart_type: str, title: str, labels: list, values: list) -> str:
    """Generate a self-contained Chart.js HTML snippet.

    Args:
        chart_type: One of "bar", "line", "pie", "doughnut".
        title: Chart title displayed above the canvas.
        labels: X-axis labels (or slice labels for pie/doughnut).
        values: Either a flat list of numbers, or a list of
            {"label": str, "data": list[number]} dicts for multi-series.
    """
    ...


async def calculate_statistics(data: list, column: str) -> dict:
    """Calculate descriptive statistics for a numeric column.

    Returns dict with keys: count, mean, median, min, max, std_dev.
    """
    ...


async def filter_data(data: list, column: str, operator: str, value: object) -> list:
    """Filter rows where column matches the condition.

    Operator: one of "==", "!=", ">", ">=", "<", "<=".
    """
    ...


ALL_TOOLS = {
    "fetch_data": fetch_data,
    "create_chart": create_chart,
    "calculate_statistics": calculate_statistics,
    "filter_data": filter_data,
}
```
The `ALL_TOOLS` dict is the single source of truth.
The agent introspects it to build the system prompt, and the sandbox uses it to resolve function calls.
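A sketch of how that introspection might work, using only the standard library (`build_system_prompt` and its output format are illustrative; the actual agent's prompt template may differ):

```python
import inspect


def build_system_prompt(tools: dict) -> str:
    """Render each tool's signature and docstring into a prompt section."""
    lines = ["You may call the following Python functions from your code:", ""]
    for name, fn in tools.items():
        # inspect.signature renders "(dataset: str) -> list" from the annotations
        lines.append(f"async def {name}{inspect.signature(fn)}:")
        doc = inspect.getdoc(fn) or "No description."
        lines.append(f'    """{doc}"""')
        lines.append("")
    return "\n".join(lines)


async def fetch_data(dataset: str) -> list:
    """Fetch tabular data by dataset name."""
    ...


prompt = build_system_prompt({"fetch_data": fetch_data})
print(prompt.splitlines()[2])  # async def fetch_data(dataset: str) -> list:
```

Because the prompt is derived from the registry, adding a tool to the dict is enough to make it visible to the model.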
## Example: programmatic tool-calling agent
The `CodeModeAgent` implements the generate-execute-retry loop:
```python
import flyte.sandbox

from _tools import ALL_TOOLS


class CodeModeAgent:
    def __init__(self, tools, *, model="claude-sonnet-4-6", max_retries=2):
        self._tools = tools
        self._model = model
        self._max_retries = max_retries
        # System prompt auto-generated from tool signatures + docstrings
        self.system_prompt = self._build_system_prompt()

    async def run(self, message: str, history: list[dict]) -> AgentResult:
        messages = [*history, {"role": "user", "content": message}]

        # Step 1: LLM generates Python code
        code = await generate_code(self._model, self.system_prompt, messages)

        # Step 2: Execute in Monty sandbox with registered tools
        for attempt in range(1 + self._max_retries):
            try:
                result = await flyte.sandbox.orchestrate_local(
                    code,
                    inputs={"_unused": 0},
                    tasks=list(self._tools.values()),
                )
                return AgentResult(
                    code=code,
                    charts=result.get("charts", []),
                    summary=result.get("summary", ""),
                )
            except Exception as exc:
                if attempt < self._max_retries:
                    # Step 3: Feed error back to LLM for retry
                    code = await generate_code(
                        self._model,
                        self.system_prompt,
                        [
                            *messages,
                            {"role": "assistant", "content": f"```python\n{code}\n```"},
                            {"role": "user", "content": f"Error: {exc}\nFix the code."},
                        ],
                    )
                    continue
                return AgentResult(code=code, error=str(exc))
```
The pattern:
1. **Generate**: The LLM receives tool signatures and the user's request, and outputs Python code.
2. **Execute**: The code runs in the Monty sandbox. Tool calls pause the sandbox, dispatch to real implementations, and resume with results.
3. **Retry**: If execution fails, the error message is fed back to the LLM, which generates a corrected version. This repeats up to `max_retries` times.
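`generate_code` itself is not shown above; whatever model client it wraps, it needs to pull a fenced code block out of the model's reply before execution. One hypothetical way to implement that extraction step:

```python
import re


def extract_code(response_text: str) -> str:
    """Return the first fenced Python block in an LLM reply, or the raw text."""
    match = re.search(r"```(?:python)?\n(.*?)```", response_text, re.DOTALL)
    return match.group(1).strip() if match else response_text.strip()


reply = 'Here is the analysis:\n```python\ndata = fetch_data("sales_2024")\n```'
print(extract_code(reply))  # data = fetch_data("sales_2024")
```

Falling back to the raw text handles models that return bare code without fences.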
## Example: chat app
Wrap the agent in a FastAPI endpoint to create a conversational analytics assistant:
```python
from _agent import CodeModeAgent
from _tools import ALL_TOOLS
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI(title="Chat Data Analytics Agent")

env = FastAPIAppEnvironment(
    name="chat-analytics-agent",
    app=app,
    image=flyte.Image.from_debian_base().with_pip_packages(
        "fastapi", "uvicorn", "httpx", "pydantic-monty",
    ),
    secrets=flyte.Secret(key="anthropic-api-key", as_env_var="ANTHROPIC_API_KEY"),
)

agent = CodeModeAgent(tools=ALL_TOOLS, max_retries=2)


@app.post("/api/chat")
async def chat(req: ChatRequest) -> ChatResponse:
    result = await agent.run(req.message, req.history)
    return ChatResponse(
        code=result.code,
        charts=result.charts,
        summary=result.summary,
        error=result.error,
    )
```
Users send natural language requests (`"Show me monthly revenue trends for 2024"`), the agent generates analysis code, the sandbox executes it with the registered tools, and the response includes charts and a text summary.
## Example: durable agent
For production workloads, wrap the tools as `@env.task` so the sandbox dispatches them as durable Flyte tasks through the controller.
This gives you execution history, retries, caching, and full observability.
```python
import _tools
from _agent import CodeModeAgent
from _tools import ALL_TOOLS

import flyte
import flyte.report

env = flyte.TaskEnvironment(
    name="llm-code-mode",
    secrets=[flyte.Secret(key="anthropic-api-key", as_env_var="ANTHROPIC_API_KEY")],
    image=flyte.Image.from_debian_base().with_pip_packages(
        "httpx", "pydantic-monty", "unionai-reuse",
    ),
)


# Wrap each tool as a durable task
@env.task
async def fetch_data(dataset: str) -> list:
    return await _tools.fetch_data(dataset)


@env.task
async def create_chart(chart_type: str, title: str, labels: list, values: list) -> str:
    return await _tools.create_chart(chart_type, title, labels, values)


# ... wrap remaining tools similarly ...

# Agent uses plain functions for prompt generation,
# @env.task versions for durable sandbox execution
durable_tools = {t.func.__name__: t for t in [fetch_data, create_chart, ...]}
agent = CodeModeAgent(tools=ALL_TOOLS, execution_tools=durable_tools)


@env.task(report=True)
async def analyze(request: str) -> str:
    """Run the code-mode agent and render an HTML report."""
    result = await agent.run(request, [])
    report_html = build_report(request, result)
    await flyte.report.replace.aio(report_html)
    await flyte.report.flush.aio()
    return result.summary
```
The key difference from the chat app: each tool call goes through the Flyte controller as a durable task.
If `fetch_data` fails, Flyte retries it automatically.
Every tool invocation is recorded and visible in the execution timeline.
Run it with:
```bash
flyte run durable_agent.py analyze \
--request "Show me monthly revenue trends for 2024, broken down by region"
```
## References
- [Code execution with MCP](https://www.anthropic.com/engineering/code-execution-with-mcp): Anthropic engineering blog on the code execution pattern
- [Code Mode](https://blog.cloudflare.com/code-mode/): Cloudflare's introduction to code mode for LLM tool calling
- [Code Mode MCP](https://blog.cloudflare.com/code-mode-mcp/): Cloudflare's server-side code mode implementation
- [Code Mode Protocol](https://github.com/universal-tool-calling-protocol/code-mode) β Open specification for the code mode pattern
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/sandboxing/code-sandboxing ===
# Code sandboxing
`flyte.sandbox.create()` runs arbitrary Python code or shell commands inside an ephemeral, stateless Docker container.
The container is built on demand from declared dependencies, executed once, and discarded.
Each invocation starts from a clean slate: no filesystem state, environment variables, or side effects carry over between runs.
## Execution modes
`flyte.sandbox.create()` supports three mutually exclusive execution modes.
### Auto-IO mode
The default mode. Write only the business logic; Flyte generates the I/O boilerplate automatically.
How it works:
1. Flyte generates an `argparse` preamble that parses declared inputs from CLI arguments.
2. Declared inputs become local variables in scope.
3. After your code runs, Flyte writes declared scalar outputs to `/var/outputs/` automatically.
```python{hl_lines=[2, 4, 6, 11]}
import flyte
import flyte.sandbox

sandbox = flyte.sandbox.create(
    name="double",
    code="result = x * 2",
    inputs={"x": int},
    outputs={"result": int},
)

result = await sandbox.run.aio(x=21)  # returns 42
```
No imports, no argument parsing, no file writing. The variable `x` is available directly, and the variable `result` is captured automatically because it matches a declared output name.
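Conceptually, the script Flyte assembles looks something like the following. This is illustrative only: the real generated preamble, flag names, and paths may differ, and real sandboxes write to `/var/outputs/` (this sketch uses `/tmp/outputs` so it can run anywhere).

```python
import argparse
import pathlib

# --- generated preamble: parse declared inputs from CLI flags ---
parser = argparse.ArgumentParser()
parser.add_argument("--x", type=int)  # one flag per declared input
args = parser.parse_args(["--x", "21"])  # Flyte passes the real values here
x = args.x

# --- your code, inserted verbatim ---
result = x * 2

# --- generated epilogue: write declared scalar outputs to files ---
out_dir = pathlib.Path("/tmp/outputs")  # real sandboxes use /var/outputs/
out_dir.mkdir(parents=True, exist_ok=True)
(out_dir / "result").write_text(str(result))
print((out_dir / "result").read_text())  # 42
```

Flyte then reads the output files back and converts them to the declared Python types.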
A more involved example with third-party packages:
```python{hl_lines=["4-9", 12, 20, 24]}
import datetime

_stats_code = """\
import numpy as np

nums = np.array([float(v) for v in values.split(",")])
mean = float(np.mean(nums))
std = float(np.std(nums))
window_end = dt + delta
"""

stats_sandbox = flyte.sandbox.create(
    name="numpy-stats",
    code=_stats_code,
    inputs={
        "values": str,
        "dt": datetime.datetime,
        "delta": datetime.timedelta,
    },
    outputs={"mean": float, "std": float, "window_end": datetime.datetime},
    packages=["numpy"],
)

mean, std, window_end = await stats_sandbox.run.aio(
    values="1,2,3,4,5",
    dt=datetime.datetime(2024, 1, 1),
    delta=datetime.timedelta(days=1),
)
```
When there are multiple outputs, `.run()` returns them as a tuple in declaration order.
### Verbatim mode
Set `auto_io=False` to run a complete Python script with full control over I/O.
Flyte runs the script exactly as written: no injected preamble, no automatic output collection.
Your script must:
- Read inputs from `/var/inputs/` (files are bind-mounted at these paths)
- Write outputs to `/var/outputs/`
```python{hl_lines=["4-9", 12, 17]}
from flyte.io import File

_etl_script = """\
import json, pathlib

payload = json.loads(pathlib.Path("/var/inputs/payload").read_text())
total = sum(payload["values"])
pathlib.Path("/var/outputs/total").write_text(str(total))
"""

etl_sandbox = flyte.sandbox.create(
    name="etl-script",
    code=_etl_script,
    inputs={"payload": File},
    outputs={"total": int},
    auto_io=False,
)

total = await etl_sandbox.run.aio(payload=payload_file)
```
Use verbatim mode when you need precise control over how inputs are read and outputs are written, or when your script has its own argument parsing.
### Command mode
Run any shell command, binary, or pipeline. Provide `command` instead of `code`.
```python{hl_lines=[5]}
from flyte.io import File

linecount_sandbox = flyte.sandbox.create(
    name="line-counter",
    command=[
        "/bin/bash",
        "-c",
        "grep -c . /var/inputs/data_file > /var/outputs/line_count || echo 0 > /var/outputs/line_count",
    ],
    inputs={"data_file": File},
    outputs={"line_count": str},
)

count = await linecount_sandbox.run.aio(data_file=data_file)
```
Command mode is useful for running test suites, compiled binaries, shell pipelines, or any non-Python workload.
Use `arguments` to pass positional arguments to the command.
File inputs are bind-mounted at `/var/inputs/` and can be referenced in the arguments list:
```python{hl_lines=[4, 5]}
sandbox = flyte.sandbox.create(
    name="test-runner",
    command=["/bin/bash", "-c", pytest_cmd],
    arguments=["/var/inputs/solution.py", "/var/inputs/tests.py"],
    inputs={"solution.py": File, "tests.py": File},
    outputs={"exit_code": str},
)
```
## Executing a sandbox
Call `.run()` on the sandbox object to build the image and execute.
**Async execution**
```python
result = await sandbox.run.aio(x=21)
```
**Sync execution**
```python
result = sandbox.run(x=21)
```
Both forms build the container image (if not already built), start the container, execute the code or command, collect outputs, and discard the container.
`flyte.sandbox.create()` defines the sandbox configuration and can be called at module level or inside a task. The actual container execution happens when you call `.run()`, which must run inside a Flyte task (either locally or remotely on the cluster).
### Error handling
If the sandbox code fails (non-zero exit code, Python exception, or timeout), `.run()` raises an exception with the error details.
If `retries` is set, Flyte automatically retries the execution before surfacing the error.
If the image build fails due to an invalid package, an `InvalidPackageError` is raised with the package name and the underlying error message.
## Supported types
Inputs and outputs must use one of the following types:
| Category | Types |
| ---------------- | ----------------------------------------- |
| **Primitive** | `int`, `float`, `str`, `bool` |
| **Date/time** | `datetime.datetime`, `datetime.timedelta` |
| **File handles** | `flyte.io.File` |
### How types are handled
**In auto-IO mode:**
- **Primitive and date/time inputs** are injected as local variables with the correct Python type. Flyte generates an `argparse` preamble behind the scenes; your code just uses the variable names directly.
- **`File` inputs** are bind-mounted into the container. The input variable contains the file path as a string (e.g., `"/var/inputs/payload"`), so you can read it with `pathlib.Path(payload).read_text()`.
- **Primitive and date/time outputs** are written to `/var/outputs/` automatically. Just assign the value to a variable matching the declared output name.
- **`File` outputs** are the exception: your code must write the file to `/var/outputs/` manually.
**In verbatim mode:**
- All inputs (including primitives) are available at `/var/inputs/`. Your script reads them directly from the filesystem.
- All outputs must be written to `/var/outputs/` by your script.
**In command mode:**
- `File` inputs are bind-mounted at `/var/inputs/`.
- All outputs must be written to `/var/outputs/` by your command.
## Configuring the container image
### Python packages
Install PyPI packages with `packages`:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="data-analysis",
    code="...",
    inputs={"data": str},
    outputs={"result": str},
    packages=["numpy", "pandas>=2.0", "scikit-learn"],
)
```
### System packages
Install system-level (apt) packages with `system_packages`:
```python{hl_lines=[7]}
sandbox = flyte.sandbox.create(
    name="image-processor",
    code="...",
    inputs={"image": File},
    outputs={"result": File},
    packages=["Pillow"],
    system_packages=["libgl1-mesa-glx", "libglib2.0-0"],
)
```
> [!NOTE]
> `gcc`, `g++`, and `make` are included automatically in every sandbox image.
### Additional Dockerfile commands
For advanced image customization, use `additional_commands` to inject arbitrary `RUN` commands into the Dockerfile:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="custom-env",
    code="...",
    inputs={"x": int},
    outputs={"y": int},
    additional_commands=["curl -sSL https://example.com/setup.sh | bash"],
)
```
### Pre-built images
Skip the image build entirely by providing a pre-built image URI:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="prebuilt",
    code="result = x + 1",
    inputs={"x": int},
    outputs={"result": int},
    image="ghcr.io/my-org/my-sandbox-image:latest",
)
```
### Image configuration
Control the registry and Python version with `ImageConfig`:
```python{hl_lines=["8-12"]}
from flyte.sandbox import ImageConfig

sandbox = flyte.sandbox.create(
    name="custom-registry",
    code="...",
    inputs={"x": int},
    outputs={"y": int},
    image_config=ImageConfig(
        registry="ghcr.io/my-org",
        registry_secret="ghcr-credentials",
        python_version=(3, 12),
    ),
)
```
## Runtime configuration
### Resources
Set CPU and memory limits for the container:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="heavy-compute",
    code="...",
    inputs={"data": str},
    outputs={"result": str},
    resources=flyte.Resources(cpu=4, memory="8Gi"),
)
```
The default is 1 CPU and 1Gi memory.
### Retries
Automatically retry failed executions:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="flaky-task",
    code="...",
    inputs={"x": int},
    outputs={"y": int},
    retries=3,
)
```
### Timeout
Set a maximum execution time in seconds:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="bounded-task",
    code="...",
    inputs={"x": int},
    outputs={"y": int},
    timeout=300,  # 5 minutes
)
```
### Environment variables
Inject environment variables into the container:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="configured-task",
    code="...",
    inputs={"x": int},
    outputs={"y": int},
    env_vars={"LOG_LEVEL": "DEBUG", "FEATURE_FLAG": "true"},
)
```
### Secrets
Mount Flyte secrets into the container:
```python{hl_lines=[6]}
sandbox = flyte.sandbox.create(
    name="authenticated-task",
    code="...",
    inputs={"query": str},
    outputs={"result": str},
    secrets=[flyte.Secret(key="api-key", as_env_var="API_KEY")],
)
```
### Caching
Control output caching behavior:
```python{hl_lines=["6-8"]}
sandbox = flyte.sandbox.create(
    name="cached-task",
    code="...",
    inputs={"x": int},
    outputs={"y": int},
    cache="auto",  # default: Flyte decides based on inputs
    # cache="override"  # force re-execution and update cache
    # cache="disable"  # no caching
)
```
## Deploying a sandbox as a task
Use `.as_task()` to convert a sandbox into a deployable `ContainerTask`.
The returned task has the generated script pre-filled as a default input, so retriggers from the UI only require user-declared inputs.
This pattern is useful when you want to define a sandbox dynamically (for example, with LLM-generated code) and then deploy it as a standalone task that others can trigger from the UI.
```python{hl_lines=[4, 11, "33-38"]}
import flyte
import flyte.sandbox
from flyte.io import File
from flyte.sandbox import sandbox_environment

# sandbox_environment provides the base runtime image for code sandboxes.
# Include it in depends_on so Flyte builds the sandbox runtime before your task runs.
env = flyte.TaskEnvironment(
    name="sandbox-demo",
    image=flyte.Image.from_debian_base(name="sandbox-demo"),
    depends_on=[sandbox_environment],
)


@env.task
async def deploy_sandbox_task() -> str:
    # Initialize the Flyte client for in-cluster operations (image building, deployment)
    flyte.init_in_cluster()

    sandbox = flyte.sandbox.create(
        name="deployable-sandbox",
        # In auto-IO mode, File inputs become path strings: read with pathlib
        code="""\
import json, pathlib

data = json.loads(pathlib.Path(payload).read_text())
total = sum(data["values"])
""",
        inputs={"payload": File},
        outputs={"total": int},
        resources=flyte.Resources(cpu=1, memory="512Mi"),
    )

    # Build the image and get a ContainerTask with the script pre-filled
    task = await sandbox.as_task.aio()

    # Create a TaskEnvironment from the task and deploy it
    deploy_env = flyte.TaskEnvironment.from_task("deployable-sandbox", task)
    versions = flyte.deploy(deploy_env)
    return versions[0].summary_repr()
```
## End-to-end example
The following example defines sandboxes in all three modes, creates helper tasks, and runs everything in a single pipeline:
```python
import datetime
from pathlib import Path

import flyte
import flyte.sandbox
from flyte.io import File
from flyte.sandbox import sandbox_environment

# sandbox_environment provides the base runtime for code sandboxes.
# Include it in depends_on so the sandbox runtime is available when tasks execute.
env = flyte.TaskEnvironment(
    name="sandbox-demo",
    image=flyte.Image.from_debian_base(name="sandbox-demo"),
    depends_on=[sandbox_environment],
)

# Auto-IO mode: pure computation
sum_sandbox = flyte.sandbox.create(
    name="sum-to-n",
    code="total = sum(range(n + 1)) if conditional else 0",
    inputs={"n": int, "conditional": bool},
    outputs={"total": int},
)

# Auto-IO mode with packages
_stats_code = """\
import numpy as np

nums = np.array([float(v) for v in values.split(",")])
mean = float(np.mean(nums))
std = float(np.std(nums))
window_end = dt + delta
"""

stats_sandbox = flyte.sandbox.create(
    name="numpy-stats",
    code=_stats_code,
    inputs={
        "values": str,
        "dt": datetime.datetime,
        "delta": datetime.timedelta,
    },
    outputs={"mean": float, "std": float, "window_end": datetime.datetime},
    packages=["numpy"],
)

# Verbatim mode: full script control
_etl_script = """\
import json, pathlib

payload = json.loads(pathlib.Path("/var/inputs/payload").read_text())
total = sum(payload["values"])
pathlib.Path("/var/outputs/total").write_text(str(total))
"""

etl_sandbox = flyte.sandbox.create(
    name="etl-script",
    code=_etl_script,
    inputs={"payload": File},
    outputs={"total": int},
    auto_io=False,
)

# Command mode: shell pipeline
linecount_sandbox = flyte.sandbox.create(
    name="line-counter",
    command=[
        "/bin/bash",
        "-c",
        "grep -c . /var/inputs/data_file > /var/outputs/line_count || echo 0 > /var/outputs/line_count",
    ],
    inputs={"data_file": File},
    outputs={"line_count": str},
)


@env.task
async def create_text_file() -> File:
    path = Path("/tmp/data.txt")
    path.write_text("line 1\n\nline 2\n")
    return await File.from_local(str(path))


@env.task
async def payload_generator() -> File:
    path = Path("/tmp/payload.json")
    path.write_text('{"values": [1, 2, 3, 4, 5]}')
    return await File.from_local(str(path))


@env.task
async def run_pipeline() -> dict:
    # Auto-IO: sum 1..10 = 55
    total = await sum_sandbox.run.aio(n=10, conditional=True)

    # Auto-IO with numpy
    mean, std, window_end = await stats_sandbox.run.aio(
        values="1,2,3,4,5",
        dt=datetime.datetime(2024, 1, 1),
        delta=datetime.timedelta(days=1),
    )

    # Verbatim ETL
    payload = await payload_generator()
    etl_total = await etl_sandbox.run.aio(payload=payload)

    # Command mode: line count
    data_file = await create_text_file()
    line_count = await linecount_sandbox.run.aio(data_file=data_file)

    return {
        "sum_1_to_10": total,
        "mean": round(mean, 4),
        "std": round(std, 4),
        "window_end": window_end.isoformat(),
        "etl_sum_1_to_10": etl_total,
        "line_count": line_count,
    }


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(run_pipeline)
    print(f"run url: {r.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/sandboxing/code_sandbox.py*
## API reference
### `flyte.sandbox.create()`
| Parameter | Type | Description |
| --------------------- | ----------------- | ---------------------------------------------------------- |
| `name` | `str` | Sandbox name. Derives task and image names. |
| `code` | `str` | Python source to run. Mutually exclusive with `command`. |
| `inputs` | `dict[str, type]` | Input type declarations. |
| `outputs` | `dict[str, type]` | Output type declarations. |
| `command` | `list[str]` | Shell command to run. Mutually exclusive with `code`. |
| `arguments` | `list[str]` | Arguments forwarded to `command`. |
| `packages` | `list[str]` | Python packages to install via pip. |
| `system_packages` | `list[str]` | System packages to install via apt. |
| `additional_commands` | `list[str]` | Extra Dockerfile `RUN` commands. |
| `resources` | `flyte.Resources` | CPU and memory limits. Default: 1 CPU, 1Gi memory. |
| `image_config` | `ImageConfig` | Registry and Python version settings. |
| `image_name` | `str` | Explicit image name (overrides auto-generated). |
| `image` | `str` | Pre-built image URI (skips build). |
| `auto_io` | `bool` | Auto-generate I/O wiring. Default: `True`. |
| `retries` | `int` | Number of retries on failure. Default: `0`. |
| `timeout` | `int` | Timeout in seconds. |
| `env_vars` | `dict[str, str]` | Environment variables for the container. |
| `secrets` | `list[Secret]` | Flyte secrets to mount. |
| `cache` | `str` | `"auto"`, `"override"`, or `"disable"`. Default: `"auto"`. |
### Sandbox methods
| Method | Description |
| --------------------------------- | ----------------------------------------------------------------- |
| `sandbox.run(**kwargs)` | Build the image and execute synchronously. Returns typed outputs. |
| `await sandbox.run.aio(**kwargs)` | Async version of `run()`. |
| `sandbox.as_task()` | Build the image and return a deployable `ContainerTask`. |
| `await sandbox.as_task.aio()` | Async version of `as_task()`. |
Both `run()` and `as_task()` accept an optional `image` parameter to provide a pre-built image URI, skipping the build step.
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials ===
# Tutorials
> **Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
This section contains tutorials that showcase relevant use cases and provide step-by-step instructions on how to implement various features using Flyte and Union.
### **Automatic prompt engineering**
Easily run prompt optimization with real-time observability, traceability, and automatic recovery.
### **GPU-accelerated climate modeling**
Run ensemble atmospheric simulations on H200 GPUs with multi-source data ingestion and real-time extreme event detection.
### **Run LLM-generated code**
Securely execute and iterate on LLM-generated code using a code agent with error reflection and retry logic.
### **Deep research**
Build an agentic workflow for deep research with multi-step reasoning and evaluation.
### **Distributed LLM pretraining**
Pretrain large language models at scale with PyTorch Lightning, FSDP, and H200 GPUs, featuring streaming data and real-time metrics.
### **Hyperparameter optimization**
Run large-scale HPO experiments with zero manual tracking, deterministic results, and automatic recovery.
### **Multi-agent trading simulation**
A multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.
### **Text-to-SQL**
Learn how to turn natural language questions into SQL queries with Flyte and LlamaIndex, and explore prompt optimization in practice.
## Subpages
- **Distributed LLM pretraining**
- **GPU-accelerated climate modeling**
- **Multi-agent trading simulation**
- **Run LLM-generated code**
- **Text-to-SQL**
- **Automatic prompt engineering**
- **Deep research**
- **Hyperparameter optimization**
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/distributed-pretraining ===
# Distributed LLM pretraining
When training large models, infrastructure should not be the hardest part. The real work is in the model architecture, the data, and the hyperparameters. In practice, though, teams often spend weeks just trying to get distributed training to run reliably.
And when it breaks, it usually breaks in familiar ways: out-of-memory crashes, corrupted checkpoints, data loaders that silently fail, or runs that hang with no obvious explanation.
Most distributed training tutorials focus on PyTorch primitives. This one focuses on getting something that actually ships. We go into the technical details, such as how FSDP shards parameters, why gradient clipping behaves differently at scale, and how streaming datasets reduce memory pressure, but always with the goal of building a system that works in production.
Real training jobs need more than a training loop. They need checkpointing, fault tolerance, data streaming, visibility into what's happening, and the ability to recover from failures. In this tutorial, we build all of that using Flyte, without having to stand up or manage any additional infrastructure.
> [!NOTE]
> Full code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/pretraining/train.py).
## Overview
We're going to pretrain a GPT-2 style language model from scratch. This involves training on raw text data starting from randomly initialized weights, rather than fine-tuning or adapting a pretrained model. This is the same process used to train the original GPT-2, LLaMA, and most other foundation models.
The model learns by predicting the next token. Given "The cat sat on the", it learns to predict "mat". Do this billions of times across terabytes of text, and the model develops surprisingly sophisticated language understanding. That's pretraining.
The challenge is scale. A 30B parameter model doesn't fit on a single GPU. The training dataset, [SlimPajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) in our case, is 627 billion tokens. Training runs last for days or even weeks. To make this work, you need:
- **Distributed training**: Split the model across multiple GPUs using [FSDP (Fully Sharded Data Parallel)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- **Data streaming**: Pull training data on-demand instead of downloading terabytes upfront
- **Checkpointing**: Save progress regularly so a failure doesn't wipe out days of compute
- **Observability**: See what's happening inside a multi-day training run
We'll build a Flyte pipeline that takes care of all of this, using three tasks with clearly defined responsibilities:
1. **Data preparation**: Tokenizes your dataset and converts it to MDS (MosaicML Data Shard) format for streaming. This Flyte task is cached, so it only needs to be run once and can be reused across runs.
2. **Distributed training**: Runs FSDP across 8 H200 GPUs. Flyte's `Elastic` plugin handles the distributed setup. Checkpoints upload to S3 automatically via Flyte's `File` abstraction.
3. **Real-time reporting**: Streams loss curves and training metrics to Flyte Reports, a live dashboard integrated into the Flyte UI.
Why three separate tasks? Flyte makes this separation efficient:
- **Caching**: The data preparation step runs once. On subsequent runs, Flyte skips it entirely.
- **Resource isolation**: Training uses expensive H200 GPUs only while actively training, while the driver runs on inexpensive CPU instances.
- **Fault boundaries**: If training fails, the data preparation step does not re-run. Training can resume directly from the most recent checkpoint.
## Implementation
Let's walk through the code. We'll start with the infrastructure setup, build the model, then wire everything together into a pipeline.
### Setting up the environment
Every distributed training job needs a consistent environment across all nodes. Flyte handles this with container images:
```
import logging
import math
import os
from pathlib import Path
from typing import Optional
import flyte
import flyte.report
import lightning as L
import numpy as np
import torch
import torch.nn as nn
from flyte.io import Dir, File
from flyteplugins.pytorch.task import Elastic
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
The imports tell the story: `flyte` for orchestration, `flyte.report` for live dashboards, `lightning` for training loop management, and `Elastic` from Flyte's PyTorch plugin. This last one is key as it configures PyTorch's distributed launch without you writing any distributed setup code.
```
NUM_NODES = 1
DEVICES_PER_NODE = 8
VOCAB_SIZE = (
50257 # GPT-2 BPE tokenizer vocabulary size (constant across all model sizes)
)
N_POSITIONS = 2048 # Maximum sequence length (constant across all model sizes)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
These constants define the distributed topology. We're using 1 node with 8 GPUs, but you can scale this up by changing `NUM_NODES`. The vocabulary size (50,257 tokens) and sequence length (2,048 tokens) match GPT-2's [Byte Pair Encoding (BPE) tokenizer](https://huggingface.co/learn/llm-course/en/chapter6/5).
```
image = flyte.Image.from_debian_base(
name="distributed_training_h200"
).with_pip_packages(
"transformers==4.57.3",
"datasets==4.4.1",
"tokenizers==0.22.1",
"huggingface-hub==0.34.0",
"mosaicml-streaming>=0.7.0",
"pyarrow==22.0.0",
"flyteplugins-pytorch>=2.0.0b33",
"torch==2.9.1",
"lightning==2.5.6",
"tensorboard==2.20.0",
"sentencepiece==0.2.1",
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
Flyte builds this container automatically when the pipeline is run. All dependencies required for distributed training, including PyTorch, Lightning, the streaming library, and NCCL for GPU communication, are baked in. There's no Dockerfile to maintain and no "works on my machine" debugging.
### Declaring resource requirements
Different parts of the pipeline need different resources. Data tokenization needs CPU and memory. Training needs GPUs. The driver just coordinates. Flyte's `TaskEnvironment` lets you declare exactly what each task needs:
```
data_loading_env = flyte.TaskEnvironment(
name="data_loading_h200",
image=image,
resources=flyte.Resources(cpu=5, memory="28Gi", disk="100Gi"),
env_vars={
"HF_DATASETS_CACHE": "/tmp/hf_cache", # Cache directory for datasets
"TOKENIZERS_PARALLELISM": "true", # Enable parallel tokenization
},
cache="auto",
)
distributed_llm_training_env = flyte.TaskEnvironment(
name="distributed_llm_training_h200",
image=image,
resources=flyte.Resources(
cpu=64,
memory="512Gi",
gpu=f"H200:{DEVICES_PER_NODE}",
disk="1Ti",
shm="16Gi", # Explicit shared memory for NCCL communication
),
plugin_config=Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE),
env_vars={
"TORCH_DISTRIBUTED_DEBUG": "INFO",
"NCCL_DEBUG": "WARN",
},
cache="auto",
)
driver_env = flyte.TaskEnvironment(
name="llm_training_driver",
image=image,
resources=flyte.Resources(cpu=2, memory="4Gi"),
cache="auto",
depends_on=[data_loading_env, distributed_llm_training_env],
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
Let's break down the training environment, since this is where most of the complexity lives:
- **`gpu=f"H200:{DEVICES_PER_NODE}"`**: Flyte provisions exactly 8 H200 GPUs. These have 141GB of memory each, enough to train 30B+ parameter models with FSDP.
- **`shm="16Gi"`**: This allocates explicit shared memory. NCCL (NVIDIA's communication library) uses shared memory for inter-GPU communication on the same node. Without this, you'll see cryptic errors like "NCCL error: unhandled system error", which can be difficult to debug.
- **`Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE)`**: This is Flyte's integration with PyTorch's elastic launch. It handles process spawning (one process per GPU), rank assignment (each process knows its ID), and environment setup (master address, world size). This replaces the boilerplate typically written in shell scripts.
The `driver_env` is intentionally lightweight, using 2 CPUs and 4 GB of memory. Its role is limited to orchestrating tasks and passing data between them, so allocating GPUs here would be unnecessary.
### Model configurations
Training a 1.5B model uses different hyperparameters than training a 65B model. Rather than hardcoding values, we define presets:
```
MODEL_CONFIGS = {
"1.5B": {
"n_embd": 2048,
"n_layer": 24,
"n_head": 16,
"batch_size": 8,
"learning_rate": 6e-4,
"checkpoint_every_n_steps": 10,
"report_every_n_steps": 5,
"val_check_interval": 100,
}, # Good for testing and debugging
"30B": {
"n_embd": 6656,
"n_layer": 48,
"n_head": 52,
"batch_size": 1,
"learning_rate": 1.6e-4,
"checkpoint_every_n_steps": 7500,
"report_every_n_steps": 200,
"val_check_interval": 1000,
},
"65B": {
"n_embd": 8192,
"n_layer": 80,
"n_head": 64,
"batch_size": 1,
"learning_rate": 1.5e-4,
"checkpoint_every_n_steps": 10000,
"report_every_n_steps": 250,
"val_check_interval": 2000,
},
}
def get_model_config(model_size: str) -> dict:
if model_size not in MODEL_CONFIGS:
available = ", ".join(MODEL_CONFIGS.keys())
raise ValueError(f"Unknown model size: {model_size}. Available: {available}")
return MODEL_CONFIGS[model_size]
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
A few things to notice:
- **Batch size decreases with model size**: For a fixed GPU memory budget, larger models consume more memory for parameters, optimizer state, and activations, leaving less room for per-GPU batch size. For example, a 1.5B parameter model may fit a batch size of 8 per GPU, while a 65B model may only fit a batch size of 1. This is typically compensated for using gradient accumulation to maintain a larger effective batch size.
- **Learning rate decreases with model size**: Larger models are more sensitive to optimization instability and typically require lower learning rates. The values here follow empirical best practices used in large-scale language model training, informed by work such as the [Chinchilla study](https://arxiv.org/pdf/2203.15556) on compute-optimal scaling.
- **Checkpoints become less frequent as model size grows**: Checkpointing a 65B model is expensive (the checkpoint is huge), so we save less often, while keeping the interval short enough that a failure doesn't cost too much progress.
The 1.5B config is good for testing your setup before committing to a serious training run.
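The lookup behaves as you'd expect; here it is restated in trimmed form (two presets only, for illustration) with a usage example:

```python
# Trimmed restatement of the preset lookup (two presets only, for illustration)
MODEL_CONFIGS = {
    "1.5B": {"n_embd": 2048, "n_layer": 24, "n_head": 16, "batch_size": 8},
    "30B": {"n_embd": 6656, "n_layer": 48, "n_head": 52, "batch_size": 1},
}

def get_model_config(model_size: str) -> dict:
    if model_size not in MODEL_CONFIGS:
        available = ", ".join(MODEL_CONFIGS.keys())
        raise ValueError(f"Unknown model size: {model_size}. Available: {available}")
    return MODEL_CONFIGS[model_size]

cfg = get_model_config("1.5B")
print(cfg["batch_size"])  # 8

try:
    get_model_config("7B")  # not a preset: fail fast, listing the valid options
except ValueError as e:
    print(e)  # Unknown model size: 7B. Available: 1.5B, 30B
```

Failing fast with the list of valid sizes beats a cryptic `KeyError` hours into a pipeline run.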
### Building the GPT model
Now for the model itself. We're building a GPT-2 style decoder-only transformer from scratch.
First, the configuration class:
```
class GPTConfig:
"""Configuration for GPT model."""
def __init__(
self,
vocab_size: int = VOCAB_SIZE,
n_positions: int = N_POSITIONS,
n_embd: int = 2048,
n_layer: int = 24,
n_head: int = 16,
n_inner: Optional[int] = None,
activation_function: str = "gelu_new",
dropout: float = 0.1,
layer_norm_epsilon: float = 1e-5,
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.n_inner = n_inner if n_inner is not None else 4 * n_embd
self.activation_function = activation_function
self.dropout = dropout
self.layer_norm_epsilon = layer_norm_epsilon
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
The key architectural parameters:
- **`n_embd`**: The hidden (embedding) dimension. Larger values increase model capacity but also increase memory and compute requirements.
- **`n_layer`**: The number of transformer blocks. Model depth strongly influences expressiveness and performance.
- **`n_head`**: The number of attention heads. Each head can attend to different patterns or relationships in the input.
- **`n_inner`**: The hidden dimension of the feed-forward network (MLP), typically set to 4x the embedding dimension.
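A back-of-the-envelope parameter count shows how these knobs set model size. This is my own rough approximation (biases and LayerNorm parameters ignored, output head tied to the token embedding):

```python
def approx_gpt_params(vocab_size: int, n_positions: int, n_embd: int, n_layer: int) -> int:
    """Rough GPT-2-style parameter count (no biases/LayerNorm, tied output head)."""
    embeddings = vocab_size * n_embd + n_positions * n_embd  # token + positional
    # Per block: attention QKV + output projection (4 * d^2) plus the 4x MLP (8 * d^2)
    per_block = 12 * n_embd * n_embd
    return embeddings + n_layer * per_block

# The "1.5B" preset: n_embd=2048, n_layer=24
print(f"{approx_gpt_params(50257, 2048, 2048, 24) / 1e9:.2f}B")  # 1.32B
```

The dominant term is `12 * n_layer * n_embd**2`, which is why widening or deepening the network grows the parameter count so quickly.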
Next, we define a single transformer block:
```
class GPTBlock(nn.Module):
"""Transformer block with causal self-attention."""
def __init__(self, config: GPTConfig):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = nn.MultiheadAttention(
config.n_embd,
config.n_head,
dropout=config.dropout,
batch_first=True,
)
self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
# Get activation function from config
ACT_FNS = {
"gelu": nn.GELU(),
"gelu_new": nn.GELU(approximate="tanh"), # GPT-2 uses approximate GELU
"relu": nn.ReLU(),
"silu": nn.SiLU(),
"swish": nn.SiLU(), # SiLU = Swish
}
act_fn = ACT_FNS.get(config.activation_function, nn.GELU())
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, config.n_inner),
act_fn,
nn.Linear(config.n_inner, config.n_embd),
nn.Dropout(config.dropout),
)
def forward(self, x, causal_mask, key_padding_mask=None):
x_normed = self.ln_1(x)
# Self-attention with causal and padding masks
attn_output, _ = self.attn(
x_normed, # query
x_normed, # key
x_normed, # value
attn_mask=causal_mask, # Causal mask: (seq_len, seq_len)
key_padding_mask=key_padding_mask, # Padding mask: (batch, seq_len)
need_weights=False,
)
x = x + attn_output
# MLP with residual
x = x + self.mlp(self.ln_2(x))
return x
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
Each block has two sub-layers: causal self-attention and a feed-forward MLP. The causal mask ensures the model can only attend to previous tokens in the sequence, so it can't "cheat" by looking at the answer. This is what makes it *autoregressive*.
The full `GPTModel` class (see the complete code) stacks these blocks and adds token and positional embeddings. One important detail is that the input token embedding matrix is shared with the output projection layer (often called [weight tying](https://mbrenndoerfer.com/writing/weight-tying-shared-embeddings-transformers)). This reduces the number of parameters by roughly 50 million for typical vocabulary sizes and often leads to better generalization and more stable training.
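Weight tying itself is a one-line operation in PyTorch. A minimal standalone sketch (not the tutorial's actual `GPTModel` code; sizes chosen for illustration):

```python
import torch.nn as nn

vocab_size, n_embd = 50257, 1024  # illustrative sizes

wte = nn.Embedding(vocab_size, n_embd)               # input token embedding
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # output projection

# Tie the weights: both modules now point at the same parameter tensor,
# so the (vocab_size x n_embd) matrix is stored and trained only once.
lm_head.weight = wte.weight

assert lm_head.weight.data_ptr() == wte.weight.data_ptr()
print(f"{vocab_size * n_embd / 1e6:.0f}M parameters shared")  # 51M parameters shared
```

Because `nn.Linear` stores its weight as `(out_features, in_features)`, the shapes line up with the embedding matrix directly, with no transpose needed.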
### The Lightning training module
PyTorch Lightning handles the training loop boilerplate. We wrap our model in a `LightningModule` that defines how to train it:
```
class GPTPreTrainingModule(L.LightningModule):
"""PyTorch Lightning module for GPT pre-training."""
def __init__(
self,
vocab_size: int = 50257,
n_positions: int = 2048,
n_embd: int = 2048,
n_layer: int = 24,
n_head: int = 16,
learning_rate: float = 6e-4,
weight_decay: float = 0.1,
warmup_steps: int = 2000,
max_steps: int = 100000,
):
super().__init__()
self.save_hyperparameters()
config = GPTConfig(
vocab_size=vocab_size,
n_positions=n_positions,
n_embd=n_embd,
n_layer=n_layer,
n_head=n_head,
)
self.model = GPTModel(config)
def forward(self, input_ids, attention_mask=None):
return self.model(input_ids, attention_mask)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
The `save_hyperparameters()` call is important because it stores all constructor arguments in the checkpoint. This allows the model to be reloaded later without having to manually reconstruct the original configuration.
The training and validation steps implement standard causal language modeling, where the model is trained to predict the next token given all previous tokens in the sequence.
```
def training_step(self, batch, _batch_idx):
# Convert int32 to int64 (long) - MDS stores as int32 but PyTorch expects long
input_ids = batch["input_ids"].long()
labels = batch["labels"].long()
# Get attention mask if present (optional, for padded sequences)
# attention_mask: 1 = real token, 0 = padding
# Note: Current data pipeline creates fixed-length sequences without padding,
# so attention_mask is not present. If using padded sequences, ensure:
# - Padded positions in labels are set to -100 (ignored by cross_entropy)
# - attention_mask marks real tokens (1) vs padding (0)
attention_mask = batch.get("attention_mask", None)
# Forward pass (causal mask is created internally in GPTModel)
logits = self(input_ids, attention_mask=attention_mask)
# Shift logits and labels for causal language modeling
# Before shift: labels[i] = input_ids[i]
# After shift: predict input_ids[i+1] from input_ids[:i+1]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Calculate loss
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
# Log loss
self.log(
"train/loss",
loss,
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True,
)
# Calculate and log perplexity only on epoch (exp is costly, less frequent is fine)
perplexity = torch.exp(torch.clamp(loss, max=20.0))
self.log(
"train/perplexity",
perplexity,
on_step=False,
on_epoch=True,
prog_bar=True,
sync_dist=True,
)
return loss
def validation_step(self, batch, _batch_idx):
# Convert int32 to int64 (long) - MDS stores as int32 but PyTorch expects long
input_ids = batch["input_ids"].long()
labels = batch["labels"].long()
# Get attention mask if present (optional, for padded sequences)
attention_mask = batch.get("attention_mask", None)
# Forward pass (causal mask is created internally in GPTModel)
logits = self(input_ids, attention_mask=attention_mask)
# Shift logits and labels
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Calculate loss
loss = nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
# Log loss
self.log("val/loss", loss, prog_bar=True, sync_dist=True)
# Calculate and log perplexity (exp is costly, but validation is infrequent so OK)
perplexity = torch.exp(torch.clamp(loss, max=20.0))
self.log("val/perplexity", perplexity, prog_bar=True, sync_dist=True)
return loss
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
The model performs a forward pass with a causal (autoregressive) mask created internally, ensuring each token can only attend to earlier positions. To align predictions with targets, the logits and labels are shifted so that the representation at position `i` is used to predict token `i + 1`.
Loss is computed using cross-entropy over the shifted logits and labels. Training loss and perplexity are logged during execution, with metrics synchronized across distributed workers.
The optimizer setup is where a lot of training stability comes from:
```
def configure_optimizers(self):
# Separate parameters into weight decay and no weight decay groups
decay_params = []
no_decay_params = []
for param in self.model.parameters():
if param.requires_grad:
# 1D parameters (biases, LayerNorm) don't get weight decay
# 2D+ parameters (weight matrices) get weight decay
if param.ndim == 1:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer_grouped_parameters = [
{"params": decay_params, "weight_decay": self.hparams.weight_decay},
{"params": no_decay_params, "weight_decay": 0.0},
]
# AdamW optimizer
optimizer = torch.optim.AdamW(
optimizer_grouped_parameters,
lr=self.hparams.learning_rate,
betas=(0.9, 0.95),
eps=1e-8,
)
# Learning rate scheduler: warmup + cosine decay
# Warmup: linear increase from 0 to 1.0 over warmup_steps
# Decay: cosine decay from 1.0 to 0.0 over remaining steps
def lr_lambda(current_step):
if current_step < self.hparams.warmup_steps:
# Linear warmup
return float(current_step) / float(max(1, self.hparams.warmup_steps))
# Cosine decay after warmup
progress = (current_step - self.hparams.warmup_steps) / max(
1, self.hparams.max_steps - self.hparams.warmup_steps
)
# Cosine annealing from 1.0 to 0.0 (returns float, not tensor)
return 0.5 * (1.0 + math.cos(progress * math.pi))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "step",
},
}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
Two important choices here:
1. **Separate weight decay groups**: We only apply weight decay to the weight matrices, not to biases or LayerNorm parameters. This follows the original BERT paper and is now standard practice, as regularizing biases and normalization parameters does not improve performance and can be harmful.
2. **Cosine learning rate schedule with warmup**: We start with a low learning rate, ramp up linearly during warmup (helps stabilize early training when gradients are noisy), then decay following a cosine curve. This schedule outperforms constant or step decay for transformer training.
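The schedule is easy to sanity-check in isolation. Evaluating the same warmup-plus-cosine logic (restated standalone with the default hyperparameters) at a few steps shows the shape:

```python
import math

warmup_steps, max_steps = 2000, 100000

def lr_lambda(current_step: int) -> float:
    """Multiplier on the base learning rate: linear warmup, then cosine decay."""
    if current_step < warmup_steps:
        return current_step / max(1, warmup_steps)
    progress = (current_step - warmup_steps) / max(1, max_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(progress * math.pi))

print(lr_lambda(0))                 # 0.0 -- start from zero
print(lr_lambda(1000))              # 0.5 -- halfway through warmup
print(lr_lambda(2000))              # 1.0 -- full base LR as warmup ends
print(round(lr_lambda(51000), 3))   # 0.5 -- midpoint of the cosine decay
print(round(lr_lambda(100000), 3))  # 0.0 -- fully decayed at max_steps
```

`LambdaLR` multiplies the optimizer's base learning rate by this factor every step, which is why the schedule is registered with `"interval": "step"`.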
### Checkpointing for fault tolerance
Training a 30B-parameter model for 15,000 steps can take days. Hardware failures and spot instance preemptions are inevitable, which makes checkpointing essential.
```
class S3CheckpointCallback(L.Callback):
"""
Periodically upload checkpoints to S3 for durability and resumption.
This ensures checkpoints are safely stored in remote storage even if
the training job is interrupted or the instance fails.
"""
def __init__(self, checkpoint_dir: Path, upload_every_n_steps: int):
super().__init__()
self.checkpoint_dir = checkpoint_dir
self.upload_every_n_steps = upload_every_n_steps
self.last_uploaded_step = -1
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
"""Upload checkpoint to S3 every N steps."""
if trainer.global_rank != 0:
return # Only upload from rank 0
current_step = trainer.global_step
# Upload every N steps (aligns with ModelCheckpoint's every_n_train_steps)
if (
current_step % self.upload_every_n_steps == 0
and current_step > self.last_uploaded_step
and current_step > 0
):
try:
# Find the most recent checkpoint file
checkpoint_files = list(self.checkpoint_dir.glob("*.ckpt"))
if not checkpoint_files:
print("No checkpoint files found to upload")
return
# Get the latest checkpoint (by modification time)
latest_checkpoint = max(
checkpoint_files, key=lambda p: p.stat().st_mtime
)
# Upload the checkpoint file directly to S3 using File.from_local_sync
checkpoint_file = File.from_local_sync(str(latest_checkpoint))
print(f"Checkpoint uploaded to S3 at: {checkpoint_file.path}")
self.last_uploaded_step = current_step
except Exception as e:
print(f"Warning: Failed to upload checkpoint to S3: {e}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
This callback runs every `N` training steps and uploads the checkpoint to durable storage. The key line is `File.from_local_sync()`, a Flyte abstraction for uploading files. There are no blob store credentials to manage and no bucket paths to hardcode. Flyte automatically uses the storage backend configured for your cluster.
The callback only runs on rank 0. Every rank would otherwise upload the same file: with DDP all ranks hold identical model states, and with FSDP's `state_dict_type="full"` the complete state dict is consolidated before saving. Having all ranks upload the same checkpoint would be wasteful and could cause race conditions.
When you restart a failed run, pass the checkpoint via `resume_checkpoint` so training resumes exactly where it left off, including the same step count, optimizer state, and learning rate schedule position.
### Real-time metrics with Flyte Reports
Multi-day training runs need observability. Is the loss decreasing? Did training diverge? Is the learning rate schedule behaving correctly? Flyte Reports let you build live dashboards directly in the UI:
```
class FlyteReportingCallback(L.Callback):
"""Custom Lightning callback to report training metrics to Flyte Report."""
def __init__(self, report_every_n_steps: int = 100):
super().__init__()
self.report_every_n_steps = report_every_n_steps
self.metrics_history = {
"step": [],
"train_loss": [],
"learning_rate": [],
"val_loss": [],
"val_perplexity": [],
}
self.initialized_report = False
self.last_logged_step = -1
def on_train_start(self, trainer, pl_module):
"""Initialize the live dashboard on training start."""
if trainer.global_rank == 0 and not self.initialized_report:
self._initialize_report()
self.initialized_report = True
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
The `_initialize_report` method (see complete code) creates an HTML/JavaScript dashboard with interactive charts. The callback then calls `flyte.report.log()` every `N` steps to push new metrics. The charts update in real-time so you can watch your loss curve descend while training runs.
There is no need to deploy Grafana, configure Prometheus, or keep a TensorBoard server running. Using `flyte.report.log()` is sufficient to get live training metrics directly in the Flyte UI.
### Streaming data at scale
Training datasets are massive. SlimPajama contains 627 billion tokens and spans hundreds of gigabytes even when compressed. Downloading the entire dataset to each training node before starting would take hours and waste storage.
Instead, we convert the data to MDS (MosaicML Data Shard) format and stream it during training:
```
@data_loading_env.task
async def load_and_prepare_streaming_dataset(
dataset_name: str,
dataset_config: Optional[str],
max_length: int,
train_split: str,
val_split: Optional[str],
max_train_samples: Optional[int],
max_val_samples: Optional[int],
shard_size_mb: int,
) -> Dir:
"""Tokenize dataset and convert to MDS format for streaming."""
from datasets import load_dataset
from streaming import MDSWriter
from transformers import GPT2TokenizerFast
output_dir = Path("/tmp/streaming_dataset")
output_dir.mkdir(parents=True, exist_ok=True)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# MDS schema: what each sample contains
columns = {
"input_ids": "ndarray:int32",
"labels": "ndarray:int32",
}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
This task does three things:
1. **Tokenizes the text** using GPT-2's BPE tokenizer
2. **Concatenates documents** into fixed-length sequences (no padding waste)
3. **Writes shards** to storage in a format optimized for streaming
The task returns a Flyte `Dir` object, which is a reference to the output location. It's not the data itself, just a pointer. When the training task receives this `Dir`, it streams shards on-demand rather than downloading everything upfront.
Flyte caches this task automatically. Run the pipeline twice with the same dataset config, and Flyte skips tokenization entirely on the second run. Change the dataset or sequence length, and it re-runs.
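The "concatenates documents" step deserves a concrete picture. A stripped-down sketch of the idea in plain Python (not the tutorial's exact implementation): tokenized documents are joined into one stream, then cut into fixed-length blocks, so no compute is wasted on padding tokens:

```python
from itertools import chain

SEQ_LEN = 8  # 2048 in the real pipeline; small here for illustration

# Pretend these are tokenized documents (lists of token IDs)
docs = [[1, 2, 3], [4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14, 15, 16, 17, 18]]

# Concatenate everything into one token stream
stream = list(chain.from_iterable(docs))

# Cut into fixed-length, padding-free training sequences; drop the remainder
sequences = [
    stream[i : i + SEQ_LEN]
    for i in range(0, len(stream) - SEQ_LEN + 1, SEQ_LEN)
]
print(sequences)
# [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]
```

Because every sequence is exactly `SEQ_LEN` tokens, the training batch never contains padding, which is also why the training step can get away without an `attention_mask`.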
### Distributed training with FSDP
Now we get to the core: actually training the model across multiple GPUs. FSDP is what makes this possible for large models.
```
@distributed_llm_training_env.task(report=True)
def train_distributed_llm(
prepared_dataset: Dir,
resume_checkpoint: Optional[Dir],
vocab_size: int,
n_positions: int,
n_embd: int,
n_layer: int,
n_head: int,
batch_size: int,
num_workers: int,
max_steps: int,
learning_rate: float,
weight_decay: float,
warmup_steps: int,
use_fsdp: bool,
checkpoint_upload_steps: int,
checkpoint_every_n_steps: int,
report_every_n_steps: int,
val_check_interval: int,
grad_accumulation_steps: int = 1,
) -> Optional[Dir]:
# ... setup code ...
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
Notice `report=True` on the task decorator. It enables Flyte Reports for this specific task.
The training task receives the prepared dataset as a `Dir` and streams data directly from storage:
```
# StreamingDataset streams shards from the remote Flyte storage on-demand
# It automatically detects torch.distributed context
# and shards data across GPUs - each rank gets different data automatically
train_dataset = StreamingDataset(
remote=f"{remote_path}/train", # Remote MDS shard location
local=str(local_cache / "train"), # Local cache for downloaded shards
shuffle=True, # Shuffle samples
shuffle_algo="naive", # Shuffling algorithm
batch_size=batch_size, # Used for shuffle buffer sizing
)
# Create validation StreamingDataset if it exists
val_dataset = None
try:
val_dataset = StreamingDataset(
remote=f"{remote_path}/validation",
local=str(local_cache / "validation"),
shuffle=False, # No shuffling for validation
batch_size=batch_size,
)
print(
f"Validation dataset initialized with streaming from: {remote_path}/validation"
)
except Exception as e:
print(f"No validation dataset found: {e}")
# Create data loaders
# StreamingDataset handles distributed sampling internally by detecting
# torch.distributed.get_rank() and torch.distributed.get_world_size()
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
drop_last=True, # Drop incomplete batches for distributed training
collate_fn=mds_collate_fn, # Handle read-only arrays
)
# Create validation loader if validation dataset exists
val_loader = None
if val_dataset is not None:
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
drop_last=False,
collate_fn=mds_collate_fn,
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
`prepared_dataset.path` provides the remote storage path for the dataset. MosaicML's `StreamingDataset` automatically shards data across GPUs so that each rank sees different samples, without requiring a manual distributed sampler. The credentials are already in the environment because Flyte set them up.
FSDP is where the memory magic happens. Instead of each GPU holding a full copy of the model (like Distributed Data Parallel (DDP)), FSDP shards the parameters, gradients, and optimizer states across all GPUs. Each GPU only holds 1/8th of the model. When a layer needs to run, FSDP all-gathers the full parameters, runs the computation, then discards them.
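To see why sharding is necessary, here is a back-of-the-envelope memory estimate. These are my own rough numbers (bf16 parameters and gradients, two fp32 Adam moments per parameter; activations and any fp32 master weights ignored, and exact optimizer bookkeeping varies by implementation):

```python
def fsdp_state_gb(n_params: float, n_gpus: int) -> tuple[float, float]:
    """(total, per-GPU) GB for params + grads + Adam state, excluding activations."""
    params_gb = n_params * 2 / 1e9  # bf16 parameters: 2 bytes each
    grads_gb = n_params * 2 / 1e9   # bf16 gradients: 2 bytes each
    adam_gb = n_params * 8 / 1e9    # two fp32 moments: 8 bytes each
    total = params_gb + grads_gb + adam_gb
    return total, total / n_gpus

total, per_gpu = fsdp_state_gb(30e9, 8)
print(f"30B model: {total:.0f} GB total, {per_gpu:.0f} GB per GPU with FULL_SHARD")
# 360 GB total -- far beyond any single GPU, but 45 GB per GPU fits easily in 141 GB
```

This is why DDP (a full replica per GPU) is a non-starter at this scale, while FULL_SHARD leaves each H200 with ample headroom for activations and communication buffers.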
```
# Configure distributed strategy
if use_fsdp:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
strategy = FSDPStrategy(
auto_wrap_policy=ModuleWrapPolicy([GPTBlock]),
activation_checkpointing_policy=None,
cpu_offload=False, # H200 has 141GB - no CPU offload needed
state_dict_type="full",
sharding_strategy="FULL_SHARD",
process_group_backend="nccl",
)
else:
from lightning.pytorch.strategies import DDPStrategy
strategy = DDPStrategy(process_group_backend="nccl")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
We wrap at the `GPTBlock` level because each transformer block becomes an FSDP unit. This balances communication overhead (more units = more all-gathers) against memory savings (smaller units = more granular sharding).
One subtle detail: gradient clipping. With FSDP, gradients are sharded across ranks, so computing a global gradient norm would require an expensive all-reduce operation. Instead of norm-based clipping, we use value-based gradient clipping, which clamps each individual gradient element to a fixed range. This can be done independently on each rank with no coordination overhead and is commonly used for large-scale FSDP training.
```
# Initialize trainer
trainer = L.Trainer(
strategy=strategy,
accelerator="gpu",
devices=DEVICES_PER_NODE,
num_nodes=NUM_NODES,
# Training configuration
max_steps=max_steps,
precision="bf16-mixed", # BFloat16 for better numerical stability
# Optimization
gradient_clip_val=1.0,
gradient_clip_algorithm=(
"value" if use_fsdp else "norm"
), # FSDP requires 'value', DDP can use 'norm'
accumulate_grad_batches=grad_accumulation_steps,
# Logging and checkpointing
callbacks=callbacks,
log_every_n_steps=report_every_n_steps,
val_check_interval=val_check_interval,
# Performance
benchmark=True,
deterministic=False,
# Enable gradient checkpointing for memory efficiency
enable_checkpointing=True,
use_distributed_sampler=False, # StreamingDataset handles distributed sampling
)
# Train the model (resume from checkpoint if provided)
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
# Print final results
if trainer.global_rank == 0:
if val_loader is not None:
print(
f"Final validation loss: {trainer.callback_metrics.get('val/loss', 0.0):.4f}"
)
print(
f"Final validation perplexity: {trainer.callback_metrics.get('val/perplexity', 0.0):.4f}"
)
print(f"Checkpoints saved to: {checkpoint_dir}")
return Dir.from_local_sync(output_dir)
return None
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
The trainer configuration brings together all the pieces we've discussed:
- **`precision="bf16-mixed"`**: BFloat16 mixed precision training. BF16 has the same dynamic range as FP32 (unlike FP16), so you don't need loss scaling. This is the standard choice for modern GPU training.
- **`gradient_clip_val=1.0`**: Clips gradients to prevent exploding gradients during training. Combined with value-based clipping for FSDP compatibility.
- **`accumulate_grad_batches`**: Accumulates gradients over multiple forward passes before updating weights. This lets us hit a larger effective batch size than what fits in GPU memory.
- **`val_check_interval`**: How often to run validation. For long training runs, validating once per epoch is far too infrequent (with a streaming dataset, an "epoch" can be enormous), so we validate every `N` training steps instead.
- **`use_distributed_sampler=False`**: We disable Lightning's built-in distributed sampler because `StreamingDataset` handles data sharding internally. Using both would cause conflicts.
- **`benchmark=True`**: Enables cuDNN autotuning. PyTorch benchmarks candidate kernels on the first batches and reuses the fastest ones for your input sizes; this matters most for convolutional layers, but it is safe here because our input shapes are fixed.
The trainer then calls `fit()` with the model, data loaders, and optionally a checkpoint path to resume from.
### Tying it together
The pipeline task orchestrates everything:
```
@driver_env.task
async def distributed_llm_pipeline(
model_size: str,
dataset_name: str = "Salesforce/wikitext",
dataset_config: str = "wikitext-103-raw-v1",
max_length: int = 2048,
max_train_samples: Optional[int] = 10000,
max_val_samples: Optional[int] = 1000,
max_steps: int = 100000,
resume_checkpoint: Optional[Dir] = None,
checkpoint_upload_steps: int = 1000,
# Optional overrides (if None, uses model preset defaults)
batch_size: Optional[int] = None,
learning_rate: Optional[float] = None,
use_fsdp: bool = True,
) -> Optional[Dir]:
# Get model configuration
model_config = get_model_config(model_size)
# Use preset values if not overridden
actual_batch_size = (
batch_size if batch_size is not None else model_config["batch_size"]
)
actual_learning_rate = (
learning_rate if learning_rate is not None else model_config["learning_rate"]
)
# Step 1: Load and prepare streaming dataset
prepared_dataset = await load_and_prepare_streaming_dataset(
dataset_name=dataset_name,
dataset_config=dataset_config,
max_length=max_length,
train_split="train",
val_split="validation",
max_train_samples=max_train_samples,
max_val_samples=max_val_samples,
shard_size_mb=64, # 64MB shards
)
# Step 2: Run distributed training
if resume_checkpoint is not None:
print("\nStep 2: Resuming distributed training from checkpoint...")
else:
print("\nStep 2: Starting distributed training from scratch...")
target_global_batch = 256
world_size = NUM_NODES * DEVICES_PER_NODE
effective_per_step = world_size * actual_batch_size
grad_accumulation_steps = max(
1, math.ceil(target_global_batch / max(1, effective_per_step))
)
checkpoint_dir = train_distributed_llm(
prepared_dataset=prepared_dataset,
resume_checkpoint=resume_checkpoint,
vocab_size=VOCAB_SIZE,
n_positions=N_POSITIONS,
n_embd=model_config["n_embd"],
n_layer=model_config["n_layer"],
n_head=model_config["n_head"],
batch_size=actual_batch_size,
num_workers=8,
max_steps=max_steps,
learning_rate=actual_learning_rate,
weight_decay=0.1,
warmup_steps=500,
use_fsdp=use_fsdp,
checkpoint_upload_steps=checkpoint_upload_steps,
checkpoint_every_n_steps=model_config["checkpoint_every_n_steps"],
report_every_n_steps=model_config["report_every_n_steps"],
val_check_interval=model_config["val_check_interval"],
grad_accumulation_steps=grad_accumulation_steps,
)
return checkpoint_dir
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
The flow is straightforward: load the configuration, prepare the data, and run training. Flyte automatically manages the execution graph so data preparation runs first and training waits until it completes. If data preparation is cached from a previous run, training starts immediately.
The gradient accumulation calculation is worth noting. We want a global batch size of 256 (this affects training dynamics), but each GPU can only fit a small batch. With 8 GPUs and batch size 1 each, we need 32 accumulation steps to hit 256.
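That arithmetic can be checked with a few lines (the helper name here is ours, not part of the pipeline):

```python
import math

def accumulation_steps(target_global_batch: int, world_size: int, per_gpu_batch: int) -> int:
    # Samples processed per optimizer-visible step across all GPUs
    effective_per_step = world_size * per_gpu_batch
    # Accumulate gradients until the target global batch is reached
    return max(1, math.ceil(target_global_batch / max(1, effective_per_step)))

print(accumulation_steps(256, world_size=8, per_gpu_batch=1))   # 32 steps
print(accumulation_steps(256, world_size=32, per_gpu_batch=4))  # 2 steps
```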
## Running the pipeline
With everything defined, running is simple:
```
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(
distributed_llm_pipeline,
model_size="30B",
dataset_name="cerebras/SlimPajama-627B",
dataset_config=None,
max_length=2048,
max_train_samples=5_000_000,
max_val_samples=50_000,
max_steps=15_000,
use_fsdp=True,
checkpoint_upload_steps=1000,
)
print(f"Run URL: {run.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*
This configuration is designed for testing and demonstration. Notice `max_train_samples=5_000_000`: that's 5 million samples from a dataset with 627 billion tokens. A tiny fraction, but enough to verify everything works without burning through compute.
For a real pretraining run, you would remove this limit by setting `max_train_samples=None`, or increase it significantly. You would also increase `max_steps` to match your compute budget, likely scale to multiple nodes by setting `NUM_NODES=4` or higher, and allocate more resources. The rest of the pipeline remains unchanged.
```bash
flyte create config --endpoint --project --domain --builder remote
uv run train.py
```
When you run this, Flyte:
1. **Builds the container** (cached after first run)
2. **Schedules data prep** on CPU nodes
3. **Waits for data prep** (or skips if cached)
4. **Provisions H200 nodes** and launches distributed training
5. **Streams logs and metrics** to the UI in real-time
Open the Flyte UI to observe the workflow execution. The data preparation task completes first, followed by the training task spinning up. As training begins, the Flyte Reports dashboard starts plotting loss curves. If anything goes wrong, the logs are immediately available in the UI.

If training fails due to an out-of-memory error, a GPU driver error, or a hardware issue, check the logs, fix the problem, and restart the run with `resume_checkpoint` pointing to the most recent checkpoint. Training resumes from where it left off. Flyte tracks the full execution history, so it is easy to see exactly what happened.
## Going further
If you've run through this tutorial, here's where to go next depending on what you're trying to do:
**You want to train on your own data.** The data prep task accepts any HuggingFace dataset with a `text` column. If your data isn't on HuggingFace, you can modify `load_and_prepare_streaming_dataset` to read from S3, local files, or any other source. The key is getting your data into MDS format. Once it's there, the streaming and sharding just work. For production training, look at SlimPajama, [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T), or [The Pile](https://huggingface.co/datasets/EleutherAI/pile) as starting points.
**You want to scale to more GPUs.** Bump `NUM_NODES` and Flyte handles the rest. The main thing to watch is the effective batch size. As you add more GPUs, you may want to reduce gradient accumulation steps to keep the same global batch size, or increase them if you want to experiment with larger batches.
**Your training keeps failing.** Add `retries=3` to your task decorator for automatic retry on transient failures. This handles spot instance preemption, temporary network issues, and the occasional GPU that decides to stop working. Combined with checkpointing, you get fault-tolerant training that can survive most infrastructure hiccups. For persistent failures, the Flyte UI logs are your friend as they capture stdout/stderr from all ranks.
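Conceptually, `retries` amounts to re-running the task when it fails. A plain-Python sketch of the idea (this is an illustration, not Flyte's actual implementation):

```python
def run_with_retries(task, max_retries: int = 3):
    # Try the task up to max_retries + 1 times, re-raising only the final failure
    for attempt in range(max_retries + 1):
        try:
            return task()
        except Exception:
            if attempt == max_retries:
                raise

# A flaky task that fails twice before succeeding
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(run_with_retries(flaky))  # "ok" on the third attempt
```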
**You want better visibility into what's happening.** We're actively working on surfacing GPU driver logs (xid/sxid errors), memory utilization breakdowns, and NCCL communication metrics directly in the Flyte UI. If you're hitting issues that the current logs don't explain, reach out. Your feedback helps us prioritize what observability features to build next!
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/climate-modeling ===
# GPU-accelerated climate modeling
Climate modeling is hard for two reasons: data and compute. Satellite imagery arrives continuously from multiple sources. Reanalysis datasets have to be pulled from remote APIs. Weather station data shows up in different formats and schemas. And once all of that is finally in one place, running atmospheric physics simulations demands serious GPU compute.
In practice, many climate workflows are held together with scripts, cron jobs, and a lot of manual babysitting. Data ingestion breaks without warning. GPU jobs run overnight with little visibility into what's happening. When something interesting shows up in a simulation, like a developing hurricane, no one notices until the job finishes hours later.
In this tutorial, we build a production-grade climate modeling pipeline using Flyte. We ingest data from three different sources in parallel, combine it with Dask, run ensemble atmospheric simulations on H200 GPUs, detect extreme weather events as they emerge, and visualize everything in a live dashboard. The entire pipeline is orchestrated, cached, and fault-tolerant, so it can run reliably at scale.

> [!NOTE]
> Full code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/climate_modeling/simulation.py).
## Overview
We're building an ensemble weather forecasting system. Ensemble forecasting runs the same simulation multiple times with slightly different initial conditions. This quantifies forecast uncertainty. Instead of saying "the temperature will be 25°C", we can say "the temperature will be 24-26°C with 90% confidence".
The pipeline has five stages:
1. **Data ingestion**: Pull satellite imagery from NOAA GOES, reanalysis data from ERA5, and surface observations from weather stations in parallel.
2. **Preprocessing**: Fuse the datasets, interpolate to a common grid, and run quality control using Dask for distributed computation.
3. **GPU simulation**: Run ensemble atmospheric physics on H200 GPUs. Each ensemble member evolves independently. PyTorch handles the tensor operations; `torch.compile` optimizes the kernels.
4. **Event detection**: Monitor for hurricanes (high wind + low pressure) and heatwaves during simulation. When extreme events are detected, the pipeline can adaptively refine the grid resolution.
5. **Real-time reporting**: Stream metrics to a live Flyte Reports dashboard showing convergence and detected events.
This workflow is a good example of where Flyte shines!
- **Parallel data ingestion**: Three different data sources, three different APIs, all running concurrently. Flyte's async task execution handles this naturally.
- **Resource heterogeneity**: Data ingestion needs CPU and network. Preprocessing needs a Dask cluster. Simulation needs GPUs. Flyte provisions exactly what each stage needs.
- **Caching**: ERA5 data fetches can take minutes. Run the pipeline twice with the same date range, and Flyte skips the fetch entirely.
- **Adaptive workflows**: When a hurricane is detected, we can dynamically refine the simulation. Flyte makes this kind of conditional logic straightforward.
## Implementation
### Dependencies and container image
```
import asyncio
import gc
import io
import json
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Literal
import flyte
import numpy as np
import pandas as pd
import xarray as xr
from flyte.io import File
from flyteplugins.dask import Dask, Scheduler, WorkerGroup
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
The key imports include `xarray` for multi-dimensional climate data, `flyteplugins.dask` for distributed preprocessing, and `flyte` for orchestration.
```
climate_image = (
flyte.Image.from_debian_base(name="climate_modeling_h200")
.with_apt_packages(
"libnetcdf-dev", # NetCDF for climate data
"libhdf5-dev", # HDF5 for large datasets
"libeccodes-dev", # GRIB format support (ECMWF's native format)
"libudunits2-dev", # Unit conversions
)
.with_pip_packages(
"numpy==2.3.5",
"pandas==2.3.3",
"xarray==2025.11.0",
"torch==2.9.1",
"netCDF4==1.7.3",
"s3fs==2025.10.0",
"aiohttp==3.13.2",
"ecmwf-datastores-client==0.4.1",
"h5netcdf==1.7.3",
"cfgrib==0.9.15.1",
"pyarrow==22.0.0",
"scipy==1.15.1",
"flyteplugins-dask>=2.0.0b33",
"nvidia-ml-py3==7.352.0",
)
.with_env_vars({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"})
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
Climate data comes in specialized formats such as NetCDF, HDF5, and GRIB. The container image includes libraries to work with all of them, along with PyTorch for GPU computation and the ECMWF client for accessing ERA5 data.
### Simulation parameters and data structures
```
@dataclass
class SimulationParams:
grid_resolution_km: float = 10.0
time_step_minutes: int = 10
simulation_hours: int = 240
physics_model: Literal["WRF", "MPAS", "CAM"] = "WRF"
boundary_layer_scheme: str = "YSU"
microphysics_scheme: str = "Thompson"
radiation_scheme: str = "RRTMG"
# Ensemble forecasting parameters
ensemble_size: int = 800
perturbation_magnitude: float = 0.5
# Convergence criteria for adaptive refinement
convergence_threshold: float = 0.1 # 10% of initial ensemble spread
max_iterations: int = 3
@dataclass
class ClimateMetrics:
timestamp: str
iteration: int
convergence_rate: float
energy_conservation_error: float
max_wind_speed_mps: float
min_pressure_mb: float
detected_phenomena: list[str]
compute_time_seconds: float
ensemble_spread: float
@dataclass
class SimulationSummary:
total_iterations: int
final_resolution_km: float
avg_convergence_rate: float
total_compute_time_seconds: float
hurricanes_detected: int
heatwaves_detected: int
converged: bool
region: str
output_files: list[File]
    date_range: list[str]
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
`SimulationParams` defines the core behavior of the simulation, including grid resolution, physics schemes, and ensemble size. The default configuration runs 800 ensemble members, which is sufficient to produce statistically meaningful uncertainty estimates.
> [!NOTE]
> Decreasing the grid spacing via `grid_resolution_km` (for example, from 10 km to 5 km) increases grid resolution and significantly increases memory usage because it introduces more data points and intermediate state. Even with 141 GB of H200 GPU memory, high-resolution or adaptively refined simulations may exceed available VRAM, especially when running large ensembles.
>
> To mitigate this, consider reducing the ensemble size, limiting the refined region, running fewer physics variables, or scaling the simulation across more GPUs so memory is distributed more evenly.
`ClimateMetrics` collects diagnostics at each iteration, such as convergence rate, energy conservation, and detected phenomena. These metrics are streamed to the real-time dashboard so you can monitor how the simulation evolves as it runs.
### Task environments
Different stages need different resources. Flyte's `TaskEnvironment` declares exactly what each task requires:
```
gpu_env = flyte.TaskEnvironment(
name="climate_modeling_gpu",
resources=flyte.Resources(
cpu=5,
memory="130Gi",
gpu="H200:1",
),
image=climate_image,
cache="auto",
)
dask_env = flyte.TaskEnvironment(
name="climate_modeling_dask",
plugin_config=Dask(
scheduler=Scheduler(resources=flyte.Resources(cpu=2, memory="6Gi")),
workers=WorkerGroup(
number_of_workers=2,
resources=flyte.Resources(cpu=2, memory="12Gi"),
),
),
image=climate_image,
resources=flyte.Resources(cpu=2, memory="12Gi"), # Head node
cache="auto",
)
cpu_env = flyte.TaskEnvironment(
name="climate_modeling_cpu",
resources=flyte.Resources(cpu=8, memory="64Gi"),
image=climate_image,
cache="auto",
secrets=[
flyte.Secret(key="cds_api_key", as_env_var="ECMWF_DATASTORES_KEY"),
flyte.Secret(key="cds_api_url", as_env_var="ECMWF_DATASTORES_URL"),
],
depends_on=[gpu_env, dask_env],
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
Here's what each environment is responsible for:
- **`gpu_env`**: Runs the atmospheric simulations on H200 GPUs. The ensemble members are held in the H200's 141 GB of VRAM during execution; the 130 Gi request is host memory for staging data.
- **`dask_env`**: Provides a distributed Dask cluster for preprocessing. A scheduler and multiple workers handle data fusion and transformation in parallel.
- **`cpu_env`**: Handles data ingestion and orchestration. This environment also includes the secrets required to access the ERA5 API.
The `depends_on` setting on `cpu_env` ensures that Flyte builds the GPU and Dask images first. Once those environments are ready, the orchestration task can launch the specialized simulation and preprocessing tasks.
### Data ingestion: multiple sources in parallel
Climate models need data from multiple sources. Each source has different formats, APIs, and failure modes. We handle them as separate Flyte tasks that run concurrently.
**Satellite imagery from NOAA GOES**
```
@cpu_env.task
async def ingest_satellite_data(region: str, date_range: list[str]) -> File:
"""Ingest GOES satellite imagery from NOAA's public S3 buckets."""
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
This task fetches cloud imagery and precipitable water products from NOAA's public S3 buckets. GOES-16 covers the Atlantic; GOES-17 covers the Pacific. The task selects the appropriate satellite based on region, fetches multiple days in parallel using `asyncio.gather`, and combines everything into a single xarray Dataset.
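The per-day fan-out looks roughly like this (the fetch logic is stubbed out here; the real task reads GOES products from NOAA's S3 buckets):

```python
import asyncio

async def fetch_day(satellite: str, day: str) -> str:
    # Placeholder for the real S3 read of that day's GOES products
    await asyncio.sleep(0)
    return f"{satellite}/{day}"

async def fetch_range(region: str, days: list[str]) -> list[str]:
    # GOES-16 covers the Atlantic; GOES-17 covers the Pacific
    satellite = "goes16" if region == "atlantic" else "goes17"
    return list(await asyncio.gather(*(fetch_day(satellite, d) for d in days)))

print(asyncio.run(fetch_range("atlantic", ["2024-09-01", "2024-09-02"])))
```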
**ERA5 reanalysis from Copernicus**
```
@cpu_env.task
async def ingest_reanalysis_data(region: str, date_range: list[str]) -> File:
"""Fetch ERA5 reanalysis from Copernicus Climate Data Store."""
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
ERA5 provides 3D atmospheric fields such as temperature, wind, and humidity at multiple pressure levels from the surface to the stratosphere. The ECMWF datastores client handles authentication via Flyte secrets. Each day's data is fetched in parallel, then concatenated.
**Surface observations from weather stations:**
```
@cpu_env.task
async def ingest_station_data(
    region: str, date_range: list[str], max_stations: int = 100
) -> File:
"""Fetch ground observations from NOAA's Integrated Surface Database."""
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
Ground truth comes from NOAA's Integrated Surface Database. The task filters stations by geographic bounds, fetches hourly observations, and returns a Parquet file for efficient downstream processing.
All three tasks return Flyte `File` objects that hold references to data in blob storage. No data moves until a downstream task actually needs it.
### Preprocessing with Dask
The three data sources need to be combined into a unified atmospheric state. This means:
- Interpolating to a common grid
- Handling missing values
- Merging variables from different sources
- Quality control
This is a perfect fit for Dask, which handles lazy evaluation over chunked arrays:
```python
@dask_env.task
async def preprocess_atmospheric_data(
satellite_data: File,
reanalysis_data: File,
station_data: File,
target_resolution_km: float,
) -> File:
```
This task connects to the Dask cluster provisioned by Flyte, loads the datasets with appropriate chunking, merges satellite and reanalysis grids, fills in missing values, and persists the result. Flyte caches the output, so preprocessing only runs when the inputs change.
### GPU-accelerated atmospheric simulation
Now the core: running atmospheric physics on the GPU. Each ensemble member is an independent forecast with slightly perturbed initial conditions.
```
@gpu_env.task
async def run_atmospheric_simulation(
input_data: File,
params: SimulationParams,
partition_id: int = 0,
ensemble_start: int | None = None,
ensemble_end: int | None = None,
) -> tuple[File, ClimateMetrics]:
"""Run GPU-accelerated atmospheric simulation with ensemble forecasting."""
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
The task accepts a subset of ensemble members (`ensemble_start` to `ensemble_end`). This enables distributing 800 members across multiple GPUs.
The physics step is the computational kernel. It runs advection (wind transport), pressure gradients, Coriolis forces, turbulent diffusion, and moisture condensation:
```
@torch.compile(mode="reduce-overhead")
def physics_step(state_tensor, dt_val, dx_val):
"""Compiled atmospheric physics - 3-4x faster with torch.compile."""
# Advection: transport by wind
temp_grad_x = torch.roll(state_tensor[:, 0], -1, dims=2) - torch.roll(
state_tensor[:, 0], 1, dims=2
)
temp_grad_y = torch.roll(state_tensor[:, 0], -1, dims=3) - torch.roll(
state_tensor[:, 0], 1, dims=3
)
advection = -(
state_tensor[:, 3] * temp_grad_x + state_tensor[:, 4] * temp_grad_y
) / (2 * dx_val)
state_tensor[:, 0] = state_tensor[:, 0] + advection * dt_val
# Pressure gradient with Coriolis
pressure_grad_x = (
torch.roll(state_tensor[:, 1], -1, dims=2)
- torch.roll(state_tensor[:, 1], 1, dims=2)
) / (2 * dx_val)
pressure_grad_y = (
torch.roll(state_tensor[:, 1], -1, dims=3)
- torch.roll(state_tensor[:, 1], 1, dims=3)
) / (2 * dx_val)
    coriolis_param = 1e-4  # ~45°N latitude
coriolis_u = coriolis_param * state_tensor[:, 4]
coriolis_v = -coriolis_param * state_tensor[:, 3]
state_tensor[:, 3] = (
state_tensor[:, 3] - pressure_grad_x * dt_val * 0.01 + coriolis_u * dt_val
)
state_tensor[:, 4] = (
state_tensor[:, 4] - pressure_grad_y * dt_val * 0.01 + coriolis_v * dt_val
)
# Turbulent diffusion
diffusion_coeff = 10.0
laplacian_temp = (
torch.roll(state_tensor[:, 0], 1, dims=2)
+ torch.roll(state_tensor[:, 0], -1, dims=2)
+ torch.roll(state_tensor[:, 0], 1, dims=3)
+ torch.roll(state_tensor[:, 0], -1, dims=3)
- 4 * state_tensor[:, 0]
) / (dx_val * dx_val)
state_tensor[:, 0] = (
state_tensor[:, 0] + diffusion_coeff * laplacian_temp * dt_val
)
# Moisture condensation
sat_vapor_pressure = 611.2 * torch.exp(
17.67 * state_tensor[:, 0] / (state_tensor[:, 0] + 243.5)
)
condensation = torch.clamp(
state_tensor[:, 2] - sat_vapor_pressure * 0.001, min=0
)
state_tensor[:, 2] = state_tensor[:, 2] - condensation * 0.1
state_tensor[:, 0] = state_tensor[:, 0] + condensation * 2.5e6 / 1005 * dt_val
return state_tensor
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
`@torch.compile(mode="reduce-overhead")` compiles this function into optimized CUDA kernels. Combined with mixed precision (`torch.cuda.amp.autocast`), this runs 3-4x faster than eager PyTorch.
Every 10 timesteps, the simulation checks for extreme events:
- **Hurricanes**: Wind speed > 33 m/s with low pressure
- **Heatwaves**: Temperature anomalies exceeding thresholds
Detected phenomena get logged to the metrics, which flow to the live dashboard.
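The checks themselves are simple threshold tests. A sketch with illustrative cutoffs (the hurricane thresholds follow the criteria above; the heatwave anomaly cutoff here is ours):

```python
def detect_events(max_wind_mps: float, min_pressure_mb: float, temp_anomaly_c: float) -> list[str]:
    events = []
    # Category 1 hurricane: sustained wind above 33 m/s with a deep surface low
    if max_wind_mps > 33.0 and min_pressure_mb < 980.0:
        events.append("hurricane")
    # Heatwave: temperature anomaly above an (illustrative) 5 degC cutoff
    if temp_anomaly_c > 5.0:
        events.append("heatwave")
    return events

print(detect_events(45.0, 965.0, 1.2))  # ['hurricane']
```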
### Distributing across multiple GPUs
800 ensemble members is a lot for one GPU, so we distribute them:
```
@cpu_env.task
async def run_distributed_simulation_ensemble(
preprocessed_data: File, params: SimulationParams, n_gpus: int
) -> tuple[list[File], list[ClimateMetrics]]:
total_members = params.ensemble_size
members_per_gpu = total_members // n_gpus
# Distribute ensemble members across GPUs
tasks = []
for gpu_id in range(n_gpus):
# Calculate ensemble range for this GPU
ensemble_start = gpu_id * members_per_gpu
# Last GPU gets any remainder members
if gpu_id == n_gpus - 1:
ensemble_end = total_members
else:
ensemble_end = ensemble_start + members_per_gpu
# Launch GPU task with ensemble subset
gpu_task = run_atmospheric_simulation(
preprocessed_data,
params,
gpu_id,
ensemble_start=ensemble_start,
ensemble_end=ensemble_end,
)
tasks.append(gpu_task)
# Execute all GPUs in parallel
results = await asyncio.gather(*tasks)
output_files = [r[0] for r in results]
metrics = [r[1] for r in results]
return output_files, metrics
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
The task splits the ensemble members evenly across the available GPUs, launches the simulation runs in parallel using `asyncio.gather`, and then aggregates the results. With five GPUs, each GPU runs 160 ensemble members. Flyte takes care of scheduling, so GPU tasks start automatically as soon as resources become available.
### The main workflow
Everything comes together in the orchestration task:
```
@cpu_env.task(report=True)
async def adaptive_climate_modeling_workflow(
region: str = "atlantic",
    date_range: list[str] = ["2024-09-01", "2024-09-10"],
current_params: SimulationParams = SimulationParams(),
enable_multi_gpu: bool = True,
n_gpus: int = 5,
) -> SimulationSummary:
"""Orchestrates multi-source ingestion, GPU simulation, and adaptive refinement."""
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
`report=True` enables Flyte Reports for live monitoring.
```
# Parallel data ingestion from three sources
with flyte.group("data-ingestion"):
satellite_task = ingest_satellite_data(region, date_range)
reanalysis_task = ingest_reanalysis_data(region, date_range)
station_task = ingest_station_data(region, date_range)
satellite_data, reanalysis_data, station_data = await asyncio.gather(
satellite_task,
reanalysis_task,
station_task,
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
`flyte.group("data-ingestion")` visually groups the ingestion tasks in the Flyte UI. Inside the group, three tasks launch concurrently. `asyncio.gather` waits for all three to complete before preprocessing begins.
The workflow then enters an iterative loop:
1. Run GPU simulation (single or multi-GPU)
2. Check convergence by comparing forecasts across iterations
3. Detect extreme events
4. If a hurricane is detected and we haven't refined yet, double the grid resolution
5. Stream metrics to the live dashboard
6. Repeat until converged or max iterations reached
Adaptive mesh refinement is the key feature here. When the simulation detects a hurricane forming, it automatically increases resolution to capture the fine-scale dynamics. This is expensive, so we limit it to one refinement per run.
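The convergence check in step 2 compares the current ensemble spread against its initial value, using the `convergence_threshold` from `SimulationParams`. Roughly (function name is ours):

```python
def has_converged(current_spread: float, initial_spread: float, threshold: float = 0.1) -> bool:
    # Converged once the spread has collapsed to 10% (by default) of where it started
    return current_spread <= threshold * initial_spread

print(has_converged(0.04, 0.5))  # True: spread fell below 10% of its initial value
print(has_converged(0.2, 0.5))   # False: the ensemble members still disagree too much
```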
### Running the pipeline
```
if __name__ == "__main__":
flyte.init_from_config()
run_multi_gpu = flyte.run(adaptive_climate_modeling_workflow)
print(f"Run URL: {run_multi_gpu.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*
Before running, set up ERA5 API credentials:
```bash
flyte create secret cds_api_key
flyte create secret cds_api_url https://cds.climate.copernicus.eu/api
```
Then launch:
```bash
flyte create config --endpoint --project --domain --builder remote
uv run simulation.py
```
The default configuration uses the Atlantic region for September 2024, which is hurricane season.
## Key concepts
### Ensemble forecasting
Weather prediction is inherently uncertain. Small errors in the initial conditions grow over time due to chaotic dynamics, which means a single forecast can only ever be one possible outcome.
Ensemble forecasting addresses this uncertainty by:
- Perturbing the initial conditions within known observational error bounds
- Running many independent forecasts
- Computing the ensemble mean as the most likely outcome and the ensemble spread as a measure of uncertainty
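The mean/spread computation can be sketched with only the standard library (the `0.3` "forecast" offset is a stand-in for the real physics):

```python
import random
import statistics

random.seed(0)
# Perturb a 25 degC initial condition within +/-0.5 degC observational error
members = [25.0 + random.uniform(-0.5, 0.5) for _ in range(800)]
# Stand-in for running the physics on each member
forecasts = sorted(t + 0.3 for t in members)

mean = statistics.fmean(forecasts)      # most likely outcome
spread = statistics.stdev(forecasts)    # forecast uncertainty
lo, hi = forecasts[40], forecasts[759]  # central ~90% of 800 members
print(f"{mean:.1f} degC, ~90% interval [{lo:.2f}, {hi:.2f}]")
```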
### Adaptive mesh refinement
When a hurricane begins to form, coarse spatial grids are not sufficient to resolve critical features like eyewall dynamics. Adaptive mesh refinement allows the simulation to focus compute where it matters most by:
- Increasing grid resolution, for example from 10 km to 5 km
- Reducing the timestep to maintain numerical stability
- Refining only the regions of interest instead of the entire domain
This approach is computationally expensive, but it is essential for producing accurate intensity forecasts.
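Halving the grid spacing and the timestep together is what keeps the scheme stable: the Courant number (wind speed times timestep over grid spacing) stays the same. A sketch (function names are ours):

```python
def refine(resolution_km: float, dt_minutes: int) -> tuple[float, int]:
    # Halve the grid spacing; halve the timestep with it to preserve stability
    return resolution_km / 2.0, max(1, dt_minutes // 2)

def courant(wind_mps: float, dt_minutes: int, dx_km: float) -> float:
    # How many grid cells an air parcel crosses per timestep
    return wind_mps * dt_minutes * 60.0 / (dx_km * 1000.0)

dx, dt = refine(10.0, 10)          # -> (5.0, 5)
print(courant(30.0, 10, 10.0))     # coarse grid: 1.8
print(courant(30.0, dt, dx))       # refined grid: still 1.8
```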
### Real-time event detection
Rather than analyzing results after a simulation completes, this pipeline detects significant events as the simulation runs.
The system monitors for conditions such as:
- **Hurricanes**: Wind speeds exceeding 33 m/s (Category 1 threshold) combined with central pressure below 980 mb
- **Heatwaves**: Sustained temperature anomalies over a defined period
Detecting these events in real time enables adaptive responses, such as refining the simulation or triggering alerts, and supports earlier warnings for extreme weather.
## Where to go next
This example is intentionally scoped to keep the ideas clear, but there are several natural ways to extend it for more realistic workloads.
To model different ocean basins, change the `region` parameter to values like `"pacific"` or `"indian"`. The ingestion tasks automatically adjust to pull the appropriate satellite coverage for each region.
To run longer forecasts, increase `simulation_hours` in `SimulationParams`. The default of 240 hours, or 10 days, is typical for medium-range forecasting, but you can run longer simulations if you have the compute budget.
Finally, the physics step here is deliberately simplified. Production systems usually incorporate additional components such as radiation schemes, boundary layer parameterizations, and land surface models. These can be added incrementally as separate steps without changing the overall structure of the pipeline.
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/trading-agents ===
# Multi-agent trading simulation
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/trading_agents); based on work by [TauricResearch](https://github.com/TauricResearch/TradingAgents).
This example walks you through building a multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.

_Trading agents execution visualization_
## TL;DR
- You'll build a trading firm made up of agents that analyze, argue, and act, modeled with Python functions.
- You'll use the Flyte SDK to orchestrate this world β giving you visibility, retries, caching, and durability.
- You'll learn how to plug in tools, structure conversations, and track decisions across agents.
- You'll see how agents debate, use context, generate reports, and retain memory via vector DBs.
## What is an agent, anyway?
Agentic workflows are a rising pattern for complex problem-solving with LLMs. Think of agents as:
- An LLM (like GPT-4 or Mistral)
- A loop that keeps them thinking until a goal is met
- A set of optional tools they can call (APIs, search, calculators, etc.)
- Enough tokens to reason about the problem at hand
That's it.
You define tools, bind them to an agent, and let it run, reasoning step-by-step, optionally using those tools, until it finishes.
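A minimal loop capturing that definition, with a stub standing in for a real LLM (all names here are illustrative, not any framework's API):

```python
def run_agent(llm, tools: dict, goal: str, max_steps: int = 5):
    # Keep asking the model what to do next until it declares a final answer
    context = [goal]
    for _ in range(max_steps):
        action = llm(context)
        if action[0] == "final":
            return action[1]
        _, tool_name, arg = action             # ("tool", name, argument)
        context.append(tools[tool_name](arg))  # feed the tool result back in
    return None  # step budget exhausted without reaching the goal

# Stub "LLM": calls the calculator once, then answers with its result
def stub_llm(context):
    if len(context) == 1:
        return ("tool", "calc", "6 * 7")
    return ("final", context[-1])

print(run_agent(stub_llm, {"calc": lambda expr: str(eval(expr))}, "What is 6 * 7?"))  # "42"
```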
## What's different here?
We're not building yet another agent framework. You're free to use LangChain, custom code, or whatever setup you like.
What we're giving you is the missing piece: a way to run these workflows **reliably, observably, and at scale, with zero rewrites.**
With Flyte, you get:
- Prompt + tool traceability and full state retention
- Built-in retries, caching, and failure recovery
- A native way to plug in your agents; no magic syntax required
## How it works: step-by-step walkthrough
This simulation is powered by a Flyte task that orchestrates multiple intelligent agents working together to analyze a company's stock and make informed trading decisions.

_Trading agents schema_
### Entry point
Everything begins with a top-level Flyte task called `main`, which serves as the entry point to the workflow.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "akshare==1.16.98",
# "backtrader==1.9.78.123",
# "boto3==1.39.9",
# "chainlit==2.5.5",
# "eodhd==1.0.32",
# "feedparser==6.0.11",
# "finnhub-python==2.4.23",
# "langchain-experimental==0.3.4",
# "langchain-openai==0.3.23",
# "pandas==2.3.0",
# "parsel==1.10.0",
# "praw==7.8.1",
# "pytz==2025.2",
# "questionary==2.1.0",
# "redis==6.2.0",
# "requests==2.32.4",
# "stockstats==0.6.5",
# "tqdm==4.67.1",
# "tushare==1.4.21",
# "typing-extensions==4.14.0",
# "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy
import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
create_neutral_debator,
create_risky_debator,
create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
reflect_bear_researcher,
reflect_bull_researcher,
reflect_research_manager,
reflect_risk_manager,
reflect_trader,
)
@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
"""Process a full trading signal to extract the core decision."""
messages = [
{
"role": "system",
"content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
},
{"role": "human", "content": full_signal},
]
return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content
async def run_analyst(analyst_name, state, online_tools):
    # Look up the analyst's factory function by name (the caller passes a deep copy of the state)
run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")
# Run the analyst's chain
result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)
# Determine the report key
report_key = (
"sentiment_report"
if analyst_name == "social_media"
else f"{analyst_name}_report"
)
report_value = getattr(result_state, report_key)
return result_state.messages[1:], report_key, report_value
# {{docs-fragment main}}
@env.task
async def main(
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
if not selected_analysts:
raise ValueError(
"No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
)
state = AgentState(
messages=[{"role": "human", "content": company_name}],
company_of_interest=company_name,
trade_date=str(trade_date),
)
# Run all analysts concurrently
results = await asyncio.gather(
*[
run_analyst(analyst, deepcopy(state), online_tools)
for analyst in selected_analysts
]
)
# Flatten and append all resulting messages into the shared state
for messages, report_attr, report in results:
state.messages.extend(messages)
setattr(state, report_attr, report)
# Bull/Bear debate loop
state = await create_bull_researcher(QUICK_THINKING_LLM, state) # Start with bull
while state.investment_debate_state.count < 2 * max_debate_rounds:
current = state.investment_debate_state.current_response
if current.startswith("Bull"):
state = await create_bear_researcher(QUICK_THINKING_LLM, state)
else:
state = await create_bull_researcher(QUICK_THINKING_LLM, state)
state = await create_research_manager(DEEP_THINKING_LLM, state)
state = await create_trader(QUICK_THINKING_LLM, state)
# Risk debate loop
state = await create_risky_debator(QUICK_THINKING_LLM, state) # Start with risky
while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
speaker = state.risk_debate_state.latest_speaker
if speaker == "Risky":
state = await create_safe_debator(QUICK_THINKING_LLM, state)
elif speaker == "Safe":
state = await create_neutral_debator(QUICK_THINKING_LLM, state)
else:
state = await create_risky_debator(QUICK_THINKING_LLM, state)
state = await create_risk_manager(DEEP_THINKING_LLM, state)
decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)
return decision, state
# {{/docs-fragment main}}
# {{docs-fragment reflect_on_decisions}}
@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
await asyncio.gather(
reflect_bear_researcher(state, returns),
reflect_bull_researcher(state, returns),
reflect_trader(state, returns),
reflect_risk_manager(state, returns),
reflect_research_manager(state, returns),
)
return "Reflection completed."
# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
returns: str,
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> str:
_, state = await main(
selected_analysts,
max_debate_rounds,
max_risk_discuss_rounds,
online_tools,
company_name,
trade_date,
)
return await reflect_and_store(state, returns)
# {{/docs-fragment reflect_on_decisions}}
# {{docs-fragment execute_main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
# print(run.url)
# {{/docs-fragment execute_main}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*
This task accepts several inputs:
- the list of analysts to run,
- the number of debate and risk discussion rounds,
- a flag to enable online tools,
- the company you're evaluating,
- and the target trading date.
The most interesting parameter here is the list of analysts to run. It determines which analyst agents will be invoked and shapes the overall structure of the simulation. Based on this input, the task dynamically launches agent tasks, running them in parallel.
The `main` task is written as a regular asynchronous Python function wrapped with Flyte's task decorator. No domain-specific language or orchestration glue is needed: just idiomatic Python, optionally using async for better performance. The task environment is configured once and shared across all tasks for consistency.
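Stripped of the Flyte decorators, the fan-out-and-merge pattern in `main` is plain `asyncio`. The sketch below is a minimal, hypothetical illustration: the `run_analyst` stub and the dict-based state stand in for the real analyst tasks and `AgentState`.

```python
import asyncio
from copy import deepcopy

# Hypothetical stand-in for an analyst task: each call receives its own
# copy of the state, so concurrent runs cannot clobber one another.
async def run_analyst(name: str, state: dict) -> tuple[str, str]:
    return f"{name}_report", f"report on {state['company']}"

async def fan_out(analysts: list[str], state: dict) -> dict:
    # Launch every analyst concurrently against an isolated copy of the state.
    results = await asyncio.gather(
        *[run_analyst(a, deepcopy(state)) for a in analysts]
    )
    # Merge each analyst's report back into the shared state.
    for key, report in results:
        state[key] = report
    return state

state = asyncio.run(fan_out(["market", "news"], {"company": "NVDA"}))
```

The `deepcopy` per launch mirrors what `main` does before `asyncio.gather`, and the merge step mirrors the `setattr` loop over the gathered results.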
```
# {{docs-fragment env}}
import flyte
QUICK_THINKING_LLM = "gpt-4o-mini"
DEEP_THINKING_LLM = "o4-mini"
env = flyte.TaskEnvironment(
name="trading-agents",
secrets=[
flyte.Secret(key="finnhub_api_key", as_env_var="FINNHUB_API_KEY"),
flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY"),
],
image=flyte.Image.from_uv_script("main.py", name="trading-agents", pre=True),
resources=flyte.Resources(cpu="1"),
cache="auto",
)
# {{/docs-fragment env}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/flyte_env.py*
### Analyst agents
Each analyst agent comes equipped with a set of tools and a carefully designed prompt tailored to its specific domain. These tools are modular Flyte tasks (for example, downloading financial reports or computing technical indicators) and benefit from Flyte's built-in caching to avoid redundant computation.
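Flyte's cache is keyed on a task's inputs: calling the same tool task twice with identical arguments reuses the stored result instead of re-running the task body. As a rough local analogy (not the Flyte implementation itself), memoization behaves the same way; the `calls` counter below only exists to show that the body runs once.

```python
from functools import lru_cache

calls = {"count": 0}  # counts actual executions of the function body

# Loose analogy for Flyte's input-keyed task cache: a repeat call with
# identical arguments returns the stored result without recomputing.
@lru_cache(maxsize=None)
def fetch_report(ticker: str, date: str) -> str:
    calls["count"] += 1
    return f"{ticker} report for {date}"

first = fetch_report("NVDA", "2024-05-12")
second = fetch_report("NVDA", "2024-05-12")  # cache hit: body not re-run
```

In the tutorial, the environment-level `cache="auto"` in `flyte_env.py` gives every tool task this behavior across runs, which is why repeated indicator or news lookups for the same ticker and date are cheap.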
```
from datetime import datetime
import pandas as pd
import tools.interface as interface
import yfinance as yf
from flyte_env import env
from flyte.io import File
@env.task
async def get_reddit_news(
curr_date: str, # Date you want to get news for in yyyy-mm-dd format
) -> str:
"""
Retrieve global news from Reddit within a specified time frame.
Args:
curr_date (str): Date you want to get news for in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest global news
from Reddit in the specified time frame.
"""
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
@env.task
async def get_finnhub_news(
    ticker: str,  # Ticker of a company, e.g. AAPL, TSM
start_date: str, # Start date in yyyy-mm-dd format
end_date: str, # End date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest news about a given stock from Finnhub within a date range
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing news about the company
within the date range from start_date to end_date
"""
end_date_str = end_date
end_date = datetime.strptime(end_date, "%Y-%m-%d")
start_date = datetime.strptime(start_date, "%Y-%m-%d")
look_back_days = (end_date - start_date).days
finnhub_news_result = interface.get_finnhub_news(
ticker, end_date_str, look_back_days
)
return finnhub_news_result
@env.task
async def get_reddit_stock_info(
ticker: str, # Ticker of a company. e.g. AAPL, TSM
curr_date: str, # Current date you want to get news for
) -> str:
"""
Retrieve the latest news about a given stock from Reddit, given the current date.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): current date in yyyy-mm-dd format to get news for
Returns:
str: A formatted dataframe containing the latest news about the company on the given date
"""
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
return stock_news_results
@env.task
async def get_YFin_data(
symbol: str, # ticker symbol of the company
start_date: str, # Start date in yyyy-mm-dd format
end_date: str, # End date in yyyy-mm-dd format
) -> str:
"""
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the stock price data
for the specified ticker symbol in the specified date range.
"""
result_data = interface.get_YFin_data(symbol, start_date, end_date)
return result_data
@env.task
async def get_YFin_data_online(
symbol: str, # ticker symbol of the company
start_date: str, # Start date in yyyy-mm-dd format
end_date: str, # End date in yyyy-mm-dd format
) -> str:
"""
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the stock price data
for the specified ticker symbol in the specified date range.
"""
result_data = interface.get_YFin_data_online(symbol, start_date, end_date)
return result_data
@env.task
async def cache_market_data(symbol: str, start_date: str, end_date: str) -> File:
data_file = f"{symbol}-YFin-data-{start_date}-{end_date}.csv"
data = yf.download(
symbol,
start=start_date,
end=end_date,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
return await File.from_local(data_file)
@env.task
async def get_stockstats_indicators_report(
symbol: str, # ticker symbol of the company
indicator: str, # technical indicator to get the analysis and report of
curr_date: str, # The current trading date you are trading on, YYYY-mm-dd
look_back_days: int = 30, # how many days to look back
) -> str:
"""
Retrieve stock stats indicators for a given ticker symbol and indicator.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
indicator (str): Technical indicator to get the analysis and report of
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
look_back_days (int): How many days to look back, default is 30
Returns:
str: A formatted dataframe containing the stock stats indicators
for the specified ticker symbol and indicator.
"""
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = start_date.strftime("%Y-%m-%d")
end_date = end_date.strftime("%Y-%m-%d")
data_file = await cache_market_data(symbol, start_date, end_date)
local_data_file = await data_file.download()
result_stockstats = interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, False, local_data_file
)
return result_stockstats
# {{docs-fragment get_stockstats_indicators_report_online}}
@env.task
async def get_stockstats_indicators_report_online(
symbol: str, # ticker symbol of the company
indicator: str, # technical indicator to get the analysis and report of
    curr_date: str,  # The current trading date you are trading on, YYYY-mm-dd
    look_back_days: int = 30,  # how many days to look back
) -> str:
"""
Retrieve stock stats indicators for a given ticker symbol and indicator.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
indicator (str): Technical indicator to get the analysis and report of
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
look_back_days (int): How many days to look back, default is 30
Returns:
str: A formatted dataframe containing the stock stats indicators
for the specified ticker symbol and indicator.
"""
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = start_date.strftime("%Y-%m-%d")
end_date = end_date.strftime("%Y-%m-%d")
data_file = await cache_market_data(symbol, start_date, end_date)
local_data_file = await data_file.download()
result_stockstats = interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, True, local_data_file
)
return result_stockstats
# {{/docs-fragment get_stockstats_indicators_report_online}}
@env.task
async def get_finnhub_company_insider_sentiment(
ticker: str, # ticker symbol for the company
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve insider sentiment information about a company (retrieved
from public SEC information) for the past 30 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the sentiment in the past 30 days starting at curr_date
"""
data_sentiment = interface.get_finnhub_company_insider_sentiment(
ticker, curr_date, 30
)
return data_sentiment
@env.task
async def get_finnhub_company_insider_transactions(
ticker: str, # ticker symbol
curr_date: str, # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve insider transaction information about a company
(retrieved from public SEC information) for the past 30 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's insider transactions/trading information in the past 30 days
"""
data_trans = interface.get_finnhub_company_insider_transactions(
ticker, curr_date, 30
)
return data_trans
@env.task
async def get_simfin_balance_sheet(
ticker: str, # ticker symbol
freq: str, # reporting frequency of the company's financial history: annual/quarterly
curr_date: str, # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve the most recent balance sheet of a company
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency of the company's financial history: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's most recent balance sheet
"""
data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)
return data_balance_sheet
@env.task
async def get_simfin_cashflow(
ticker: str, # ticker symbol
freq: str, # reporting frequency of the company's financial history: annual/quarterly
curr_date: str, # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve the most recent cash flow statement of a company
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency of the company's financial history: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's most recent cash flow statement
"""
data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)
return data_cashflow
@env.task
async def get_simfin_income_stmt(
ticker: str, # ticker symbol
freq: str, # reporting frequency of the company's financial history: annual/quarterly
curr_date: str, # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve the most recent income statement of a company
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency of the company's financial history: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's most recent income statement
"""
data_income_stmt = interface.get_simfin_income_statements(ticker, freq, curr_date)
return data_income_stmt
@env.task
async def get_google_news(
query: str, # Query to search with
curr_date: str, # Curr date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest news from Google News based on a query and date range.
Args:
query (str): Query to search with
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest news from Google News
based on the query and date range.
"""
google_news_results = interface.get_google_news(query, curr_date, 7)
return google_news_results
@env.task
async def get_stock_news_openai(
ticker: str, # the company's ticker
curr_date: str, # Current date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest news about a given stock by using OpenAI's news API.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest news about the company on the given date.
"""
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
return openai_news_results
@env.task
async def get_global_news_openai(
curr_date: str, # Current date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
Args:
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest macroeconomic news on the given date.
"""
openai_news_results = interface.get_global_news_openai(curr_date)
return openai_news_results
@env.task
async def get_fundamentals_openai(
ticker: str, # the company's ticker
curr_date: str, # Current date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest fundamental information about a given stock
on a given date by using OpenAI's news API.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest fundamental information
about the company on the given date.
"""
openai_fundamentals_results = interface.get_fundamentals_openai(ticker, curr_date)
return openai_fundamentals_results
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/tools/toolkit.py*
When initialized, an analyst enters a structured reasoning loop (via LangChain), where it can call tools, observe outputs, and refine its internal state before generating a final report. These reports are later consumed by downstream agents.
Here's an example of a news analyst that interprets global events and macroeconomic signals. We specify the tools accessible to the analyst, and the LLM selects which ones to use based on context.
```
import asyncio
from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit
import flyte
MAX_ITERATIONS = 5
# {{docs-fragment agent_helper}}
async def run_chain_with_tools(
type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
" For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join(tool_names))
prompt = prompt.partial(current_date=state.trade_date)
prompt = prompt.partial(ticker=state.company_of_interest)
chain = prompt | ChatOpenAI(model=llm).bind_tools(
[getattr(toolkit, tool_name).func for tool_name in tool_names]
)
iteration = 0
while iteration < MAX_ITERATIONS:
result = await chain.ainvoke(state.messages)
state.messages.append(convert_to_openai_messages(result))
if not result.tool_calls:
            # Final response: no tools required
setattr(state, f"{type}_report", result.content or "")
break
# Run all tool calls in parallel
async def run_single_tool(tool_call):
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool = getattr(toolkit, tool_name, None)
if not tool:
return None
content = await tool(**tool_args)
return ToolMessage(
tool_call_id=tool_call["id"], name=tool_name, content=content
)
with flyte.group(f"tool_calls_iteration_{iteration}"):
tool_messages = await asyncio.gather(
*[run_single_tool(tc) for tc in result.tool_calls]
)
# Add valid tool results to state
tool_messages = [msg for msg in tool_messages if msg]
state.messages.extend(convert_to_openai_messages(tool_messages))
iteration += 1
else:
        # Reached iteration cap; optionally raise or log
print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")
return state
# {{/docs-fragment agent_helper}}
@env.task
async def create_fundamentals_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [toolkit.get_fundamentals_openai]
else:
tools = [
toolkit.get_finnhub_company_insider_sentiment,
toolkit.get_finnhub_company_insider_transactions,
toolkit.get_simfin_balance_sheet,
toolkit.get_simfin_cashflow,
toolkit.get_simfin_income_stmt,
]
system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. "
"Please write a comprehensive report of the company's fundamental information such as financial documents, "
"company profile, basic company financials, company financial history, insider sentiment, and insider "
"transactions to gain a full view of the company's "
"fundamental information to inform traders. Make sure to include as much detail as possible. "
"Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions. "
"Make sure to append a Markdown table at the end of the report to organize key points in the report, "
"organized and easy to read."
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools(
"fundamentals", state, llm, system_message, tool_names
)
@env.task
async def create_market_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [
toolkit.get_YFin_data_online,
toolkit.get_stockstats_indicators_report_online,
]
else:
tools = [
toolkit.get_YFin_data,
toolkit.get_stockstats_indicators_report,
]
system_message = (
"""You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:
Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.
MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.
Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.
Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and fine-grained analysis
and insights that may help traders make decisions."""
""" Make sure to append a Markdown table at the end of the report to
organize key points in the report, organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools("market", state, llm, system_message, tool_names)
# {{docs-fragment news_analyst}}
@env.task
async def create_news_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [
toolkit.get_global_news_openai,
toolkit.get_google_news,
]
else:
tools = [
toolkit.get_finnhub_news,
toolkit.get_reddit_news,
toolkit.get_google_news,
]
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. "
"Please write a comprehensive report of the current state of the world that is relevant for "
"trading and macroeconomics. "
        "Look at news from EODHD and Finnhub to be comprehensive. Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions."
""" Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools("news", state, llm, system_message, tool_names)
# {{/docs-fragment news_analyst}}
@env.task
async def create_social_media_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [toolkit.get_stock_news_openai]
else:
tools = [toolkit.get_reddit_stock_info]
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, "
"recent company news, and public sentiment for a specific company over the past week. "
        "You will be given a company's name. Your objective is to write a comprehensive long report "
"detailing your analysis, insights, and implications for traders and investors on this company's current state "
"after looking at social media and what people are saying about that company, "
"analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
"Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
        "are mixed, provide detailed and fine-grained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools(
"sentiment", state, llm, system_message, tool_names
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py*
Each analyst agent uses a helper function to bind tools, iterate through reasoning steps (up to a configurable maximum), and produce an answer. Setting a max iteration count is crucial to prevent runaway loops. As agents reason, their message history is preserved in their internal state and passed along to the next agent in the chain.
```
import asyncio
from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit
import flyte
MAX_ITERATIONS = 5
# {{docs-fragment agent_helper}}
async def run_chain_with_tools(
type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
" For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join(tool_names))
prompt = prompt.partial(current_date=state.trade_date)
prompt = prompt.partial(ticker=state.company_of_interest)
chain = prompt | ChatOpenAI(model=llm).bind_tools(
[getattr(toolkit, tool_name).func for tool_name in tool_names]
)
iteration = 0
while iteration < MAX_ITERATIONS:
result = await chain.ainvoke(state.messages)
state.messages.append(convert_to_openai_messages(result))
if not result.tool_calls:
            # Final response: no tools required
setattr(state, f"{type}_report", result.content or "")
break
# Run all tool calls in parallel
async def run_single_tool(tool_call):
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool = getattr(toolkit, tool_name, None)
if not tool:
return None
content = await tool(**tool_args)
return ToolMessage(
tool_call_id=tool_call["id"], name=tool_name, content=content
)
with flyte.group(f"tool_calls_iteration_{iteration}"):
tool_messages = await asyncio.gather(
*[run_single_tool(tc) for tc in result.tool_calls]
)
# Add valid tool results to state
tool_messages = [msg for msg in tool_messages if msg]
state.messages.extend(convert_to_openai_messages(tool_messages))
iteration += 1
else:
        # Reached iteration cap; optionally raise or log
print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")
return state
# {{/docs-fragment agent_helper}}
@env.task
async def create_fundamentals_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [toolkit.get_fundamentals_openai]
else:
tools = [
toolkit.get_finnhub_company_insider_sentiment,
toolkit.get_finnhub_company_insider_transactions,
toolkit.get_simfin_balance_sheet,
toolkit.get_simfin_cashflow,
toolkit.get_simfin_income_stmt,
]
system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. "
"Please write a comprehensive report of the company's fundamental information such as financial documents, "
"company profile, basic company financials, company financial history, insider sentiment, and insider "
"transactions to gain a full view of the company's "
"fundamental information to inform traders. Make sure to include as much detail as possible. "
"Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions. "
"Make sure to append a Markdown table at the end of the report to organize key points in the report, "
"organized and easy to read."
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools(
"fundamentals", state, llm, system_message, tool_names
)
@env.task
async def create_market_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [
toolkit.get_YFin_data_online,
toolkit.get_stockstats_indicators_report_online,
]
else:
tools = [
toolkit.get_YFin_data,
toolkit.get_stockstats_indicators_report,
]
system_message = (
"""You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:
Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.
MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.
Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.
Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and fine-grained analysis
and insights that may help traders make decisions."""
""" Make sure to append a Markdown table at the end of the report to
organize key points in the report, organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools("market", state, llm, system_message, tool_names)
# {{docs-fragment news_analyst}}
@env.task
async def create_news_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [
toolkit.get_global_news_openai,
toolkit.get_google_news,
]
else:
tools = [
toolkit.get_finnhub_news,
toolkit.get_reddit_news,
toolkit.get_google_news,
]
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. "
"Please write a comprehensive report of the current state of the world that is relevant for "
"trading and macroeconomics. "
"Look at news from EODHD and finnhub to be comprehensive. Do not simply state the trends are mixed, "
"provide detailed and fine-grained analysis and insights that may help traders make decisions."
""" Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools("news", state, llm, system_message, tool_names)
# {{/docs-fragment news_analyst}}
@env.task
async def create_social_media_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [toolkit.get_stock_news_openai]
else:
tools = [toolkit.get_reddit_stock_info]
system_message = (
"You are a social media and company-specific news researcher/analyst tasked with analyzing social media posts, "
"recent company news, and public sentiment for a specific company over the past week. "
"You will be given a company's name; your objective is to write a comprehensive long report "
"detailing your analysis, insights, and implications for traders and investors on this company's current state "
"after looking at social media and what people are saying about that company, "
"analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
"Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
"are mixed, provide detailed and fine-grained analysis and insights that may help traders make decisions."
""" Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools(
"sentiment", state, llm, system_message, tool_names
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py*
Once all analyst reports are complete, their outputs are collected and passed to the next stage of the workflow.
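The fan-out/fan-in shape of this stage can be sketched in plain asyncio, independent of Flyte. The analyst stubs and `State` fields below are illustrative stand-ins for the tasks above, with the LLM calls replaced by fixed strings:

```python
import asyncio
from copy import deepcopy
from dataclasses import dataclass

@dataclass
class State:
    # One report slot per analyst, mirroring AgentState in the tutorial
    market_report: str = ""
    news_report: str = ""

async def market_analyst(state: State) -> State:
    state.market_report = "market: uptrend"  # stub for the real LLM call
    return state

async def news_analyst(state: State) -> State:
    state.news_report = "news: neutral"      # stub for the real LLM call
    return state

async def run_analysts(state: State) -> State:
    # Fan out: each analyst works on its own copy so they run independently
    market_state, news_state = await asyncio.gather(
        market_analyst(deepcopy(state)),
        news_analyst(deepcopy(state)),
    )
    # Fan in: collect each analyst's report back onto the shared state
    state.market_report = market_state.market_report
    state.news_report = news_state.news_report
    return state

final_state = asyncio.run(run_analysts(State()))
```

`asyncio.gather` preserves argument order in its results, which is what makes the fan-in assignments safe.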
### Research agents
The research phase consists of two agents: a bullish researcher and a bearish one. They evaluate the company from opposing viewpoints, drawing on the analysts' reports. Unlike analysts, they don't use tools. Their role is to interpret, critique, and develop positions based on the evidence.
```
from agents.utils.utils import AgentState, InvestmentDebateState, memory_init
from flyte_env import env
from langchain_openai import ChatOpenAI
# {{docs-fragment bear_researcher}}
@env.task
async def create_bear_researcher(llm: str, state: AgentState) -> AgentState:
investment_debate_state = state.investment_debate_state
history = investment_debate_state.history
bear_history = investment_debate_state.bear_history
current_response = investment_debate_state.current_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="bear-researcher")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bear Analyst making the case against investing in the stock.
Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators.
Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
Key points to focus on:
- Risks and Challenges: Highlight factors like market saturation, financial instability,
or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation,
or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning,
exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points
and debating effectively rather than simply listing facts.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate
that demonstrates the risks and weaknesses of investing in the stock.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Bear Analyst: {response.content}"
new_investment_debate_state = InvestmentDebateState(
history=history + "\n" + argument,
bear_history=bear_history + "\n" + argument,
bull_history=investment_debate_state.bull_history,
current_response=argument,
count=investment_debate_state.count + 1,
)
state.investment_debate_state = new_investment_debate_state
return state
# {{/docs-fragment bear_researcher}}
@env.task
async def create_bull_researcher(llm: str, state: AgentState) -> AgentState:
investment_debate_state = state.investment_debate_state
history = investment_debate_state.history
bull_history = investment_debate_state.bull_history
current_response = investment_debate_state.current_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="bull-researcher")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bull Analyst advocating for investing in the stock.
Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages,
and positive market indicators.
Leverage the provided research and data to address concerns and counter bearish arguments effectively.
Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing
concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points
and debating effectively rather than just listing data.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate
that demonstrates the strengths of the bull position.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Bull Analyst: {response.content}"
new_investment_debate_state = InvestmentDebateState(
history=history + "\n" + argument,
bull_history=bull_history + "\n" + argument,
bear_history=investment_debate_state.bear_history,
current_response=argument,
count=investment_debate_state.count + 1,
)
state.investment_debate_state = new_investment_debate_state
return state
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/researchers.py*
To aid reasoning, the agents can also retrieve relevant "memories" from a vector database, giving them richer historical context. The number of debate rounds is configurable, and after a few iterations of back-and-forth between the bull and bear, a research manager agent reviews their arguments and makes a final investment decision.
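The round structure can be pictured as a simple loop: the bull and bear alternate for a configurable number of rounds, then a manager reviews the transcript and decides. The toy functions below are illustrative stand-ins for the LLM-backed tasks in this tutorial:

```python
def bull_turn(history: list[str]) -> list[str]:
    return history + ["Bull: growth and competitive advantages outweigh risks"]

def bear_turn(history: list[str]) -> list[str]:
    return history + ["Bear: valuation and macro risks dominate"]

def manager_decision(history: list[str]) -> str:
    # Stand-in for the research manager's LLM call over the debate history
    return "BUY" if history else "HOLD"

def run_debate(max_rounds: int = 2) -> tuple[str, list[str]]:
    history: list[str] = []
    for _ in range(max_rounds):
        history = bull_turn(history)
        history = bear_turn(history)
    return manager_decision(history), history

decision, transcript = run_debate(max_rounds=2)
```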
```
from agents.utils.utils import (
AgentState,
InvestmentDebateState,
RiskDebateState,
memory_init,
)
from flyte_env import env
from langchain_openai import ChatOpenAI
# {{docs-fragment research_manager}}
@env.task
async def create_research_manager(llm: str, state: AgentState) -> AgentState:
history = state.investment_debate_state.history
investment_debate_state = state.investment_debate_state
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="research-manager")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate
this round of debate and make a definitive decision:
align with the bear analyst, the bull analyst,
or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning.
Your recommendation (Buy, Sell, or Hold) must be clear and actionable.
Avoid defaulting to Hold simply because both sides have valid points;
commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations.
Use these insights to refine your decision-making and ensure you are learning and improving.
Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"
Here is the debate:
Debate History:
{history}"""
response = ChatOpenAI(model=llm).invoke(prompt)
new_investment_debate_state = InvestmentDebateState(
judge_decision=response.content,
history=investment_debate_state.history,
bear_history=investment_debate_state.bear_history,
bull_history=investment_debate_state.bull_history,
current_response=response.content,
count=investment_debate_state.count,
)
state.investment_debate_state = new_investment_debate_state
state.investment_plan = response.content
return state
# {{/docs-fragment research_manager}}
@env.task
async def create_risk_manager(llm: str, state: AgentState) -> AgentState:
history = state.risk_debate_state.history
risk_debate_state = state.risk_debate_state
trader_plan = state.investment_plan
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="risk-manager")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the Risk Management Judge and Debate Facilitator,
your goal is to evaluate the debate between three risk analysts (Risky,
Neutral, and Safe/Conservative) and determine the best course of action for the trader.
Your decision must result in a clear recommendation: Buy, Sell, or Hold.
Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid.
Strive for clarity and decisiveness.
Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**,
and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments
and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.
---
**Analysts Debate History:**
{history}
---
Focus on actionable insights and continuous improvement.
Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
response = ChatOpenAI(model=llm).invoke(prompt)
new_risk_debate_state = RiskDebateState(
judge_decision=response.content,
history=risk_debate_state.history,
risky_history=risk_debate_state.risky_history,
safe_history=risk_debate_state.safe_history,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Judge",
current_risky_response=risk_debate_state.current_risky_response,
current_safe_response=risk_debate_state.current_safe_response,
current_neutral_response=risk_debate_state.current_neutral_response,
count=risk_debate_state.count,
)
state.risk_debate_state = new_risk_debate_state
state.final_trade_decision = response.content
return state
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/managers.py*
### Trading agent
The trader agent consolidates the insights from analysts and researchers to generate a final recommendation. It synthesizes competing signals and produces a conclusion such as _Buy for long-term growth despite short-term volatility_.
```
from agents.utils.utils import AgentState, memory_init
from flyte_env import env
from langchain_core.messages import convert_to_openai_messages
from langchain_openai import ChatOpenAI
# {{docs-fragment trader}}
@env.task
async def create_trader(llm: str, state: AgentState) -> AgentState:
company_name = state.company_of_interest
investment_plan = state.investment_plan
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="trader")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
context = {
"role": "user",
"content": f"Based on a comprehensive analysis by a team of analysts, "
f"here is an investment plan tailored for {company_name}. "
"This plan incorporates insights from current technical market trends, "
"macroeconomic indicators, and social media sentiment. "
"Use this plan as a foundation for evaluating your next trading decision.\n\n"
f"Proposed Investment Plan: {investment_plan}\n\n"
"Leverage these insights to make an informed and strategic decision.",
}
messages = [
{
"role": "system",
"content": f"""You are a trading agent analyzing market data to make investment decisions.
Based on your analysis, provide a specific recommendation to buy, sell, or hold.
End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'
to confirm your recommendation.
Do not forget to utilize lessons from past decisions to learn from your mistakes.
Here are some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
},
context,
]
result = ChatOpenAI(model=llm).invoke(messages)
state.messages.append(convert_to_openai_messages(result))
state.trader_investment_plan = result.content
state.sender = "Trader"
return state
# {{/docs-fragment trader}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/trader.py*
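Because the trader is instructed to end with `FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`, the core decision can be pulled out of its output with a small regex before any further processing. This helper is an illustrative sketch, not part of the tutorial source:

```python
import re

def extract_decision(signal: str) -> str:
    """Return BUY, SELL, or HOLD from the trader's closing line, if present."""
    match = re.search(
        r"FINAL TRANSACTION PROPOSAL:\s*\*{0,2}(BUY|SELL|HOLD)\*{0,2}", signal
    )
    return match.group(1) if match else "UNKNOWN"

decision = extract_decision(
    "Momentum is strong and fundamentals support entry.\n"
    "FINAL TRANSACTION PROPOSAL: **BUY**"
)
```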
### Risk agents
This stage comprises three agents with different risk tolerances: a risky debater, a neutral one, and a conservative one. They assess the portfolio through lenses like market volatility, liquidity, and systemic risk. Similar to the bull-bear debate, these agents engage in internal discussion, after which a risk manager makes the final call.
```
from agents.utils.utils import AgentState, RiskDebateState
from flyte_env import env
from langchain_openai import ChatOpenAI
# {{docs-fragment risk_debator}}
@env.task
async def create_risky_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
risky_history = risk_debate_state.risky_history
current_safe_response = risk_debate_state.current_safe_response
current_neutral_response = risk_debate_state.current_neutral_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities,
emphasizing bold strategies and competitive advantages.
When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential,
and innovative benefits, even when these come with elevated risk.
Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views.
Specifically, respond directly to each point made by the conservative and neutral analysts,
countering with data-driven rebuttals and persuasive reasoning.
Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative.
Here is the trader's decision:
{trader_decision}
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative
and neutral stances to demonstrate why your high-reward perspective offers the best path forward.
Incorporate insights from the following sources into your arguments:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here are the last arguments from the conservative analyst: {current_safe_response}
Here are the last arguments from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic,
and asserting the benefits of risk-taking to outpace market norms.
Maintain a focus on debating and persuading, not just presenting data.
Challenge each counterpoint to underscore why a high-risk approach is optimal.
Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Risky Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risky_history + "\n" + argument,
safe_history=risk_debate_state.safe_history,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Risky",
current_risky_response=argument,
current_safe_response=current_safe_response,
current_neutral_response=current_neutral_response,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
# {{/docs-fragment risk_debator}}
@env.task
async def create_safe_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
safe_history = risk_debate_state.safe_history
current_risky_response = risk_debate_state.current_risky_response
current_neutral_response = risk_debate_state.current_neutral_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets,
minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation,
carefully assessing potential losses, economic downturns, and market volatility.
When evaluating the trader's decision or plan, critically examine high-risk elements,
pointing out where the decision may expose the firm to undue risk and where more cautious
alternatives could secure long-term gains.
Here is the trader's decision:
{trader_decision}
Your task is to actively counter the arguments of the Risky and Neutral Analysts,
highlighting where their views may overlook potential threats or fail to prioritize sustainability.
Respond directly to their points, drawing from the following data sources
to build a convincing case for a low-risk approach adjustment to the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked.
Address each of their counterpoints to showcase why a conservative stance is ultimately the
safest path for the firm's assets.
Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy
over their approaches.
Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Safe Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risk_debate_state.risky_history,
safe_history=safe_history + "\n" + argument,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Safe",
current_risky_response=current_risky_response,
current_safe_response=argument,
current_neutral_response=current_neutral_response,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
@env.task
async def create_neutral_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
neutral_history = risk_debate_state.neutral_history
current_risky_response = risk_debate_state.current_risky_response
current_safe_response = risk_debate_state.current_safe_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective,
weighing both the potential benefits and risks of the trader's decision or plan.
You prioritize a well-rounded approach, evaluating the upsides
and downsides while factoring in broader market trends,
potential economic shifts, and diversification strategies. Here is the trader's decision:
{trader_decision}
Your task is to challenge both the Risky and Safe Analysts,
pointing out where each perspective may be overly optimistic or overly cautious.
Use insights from the following data sources to support a moderate, sustainable strategy
to adjust the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the safe analyst: {current_safe_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by analyzing both sides critically, addressing weaknesses in the risky
and conservative arguments to advocate for a more balanced approach.
Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds,
providing growth potential while safeguarding against extreme volatility.
Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to
the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Neutral Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risk_debate_state.risky_history,
safe_history=risk_debate_state.safe_history,
neutral_history=neutral_history + "\n" + argument,
latest_speaker="Neutral",
current_risky_response=current_risky_response,
current_safe_response=current_safe_response,
current_neutral_response=argument,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/risk_debators.py*
The outcome of the risk manager, whether to proceed with the trade or not, is considered the final decision of the trading simulation.
You can visualize this full pipeline in the Flyte/Union UI, where every step is logged.
You'll see input/output metadata for each tool and agent task.
Thanks to Flyte's caching, repeated steps are skipped unless inputs change, saving time and compute resources.
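The skip-on-unchanged-inputs behavior can be pictured as a cache keyed by a task's inputs. The dictionary-based memoizer below is a toy illustration of the idea, not Flyte's actual implementation:

```python
calls = []

def expensive_analysis(ticker: str, trade_date: str) -> str:
    # Stand-in for a slow task; `calls` tracks real executions
    calls.append((ticker, trade_date))
    return f"report for {ticker} on {trade_date}"

_cache: dict = {}

def cached(fn):
    def wrapper(*args):
        if args not in _cache:       # recompute only when the inputs change
            _cache[args] = fn(*args)
        return _cache[args]
    return wrapper

analysis = cached(expensive_analysis)
first = analysis("NVDA", "2024-05-10")
repeat = analysis("NVDA", "2024-05-10")   # cache hit: no second execution
changed = analysis("NVDA", "2024-05-11")  # new inputs: fresh execution
```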
### Retaining agent memory with S3 vectors
To help agents learn from past decisions, we persist their memory in a vector store. In this example, we use an [S3 vector](https://aws.amazon.com/s3/features/vectors/) bucket for its simplicity and tight integration with Flyte and Union, but any vector database can be used.
Note: To use the S3 vector store, make sure your IAM role has the following permissions configured:
```
s3vectors:CreateVectorBucket
s3vectors:CreateIndex
s3vectors:PutVectors
s3vectors:GetIndex
s3vectors:GetVectors
s3vectors:QueryVectors
s3vectors:GetVectorBucket
```
After each trade decision, you can run a `reflect_on_decisions` task. This evaluates whether the final outcome aligned with the agent's recommendation and stores that reflection in the vector store. These stored insights can later be retrieved to provide historical context and improve future decision-making.
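Since any vector database can stand in for S3 vectors, here is a minimal in-memory sketch of the store-and-retrieve pattern the reflection step relies on; the `ReflectionStore` class and the toy embeddings are illustrative, not part of the tutorial code:

```python
import math

class ReflectionStore:
    """Toy in-memory stand-in for a vector index (S3 Vectors, etc.)."""

    def __init__(self):
        self._items: list[tuple[list[float], str]] = []

    def put(self, embedding: list[float], reflection: str) -> None:
        # Store the reflection text keyed by its embedding vector.
        self._items.append((embedding, reflection))

    def query(self, embedding: list[float], top_k: int = 1) -> list[str]:
        # Rank stored reflections by cosine similarity to the query vector.
        def cosine(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            na = math.sqrt(sum(x * x for x in a))
            nb = math.sqrt(sum(y * y for y in b))
            return dot / (na * nb)
        ranked = sorted(self._items, key=lambda it: cosine(embedding, it[0]), reverse=True)
        return [text for _, text in ranked[:top_k]]

store = ReflectionStore()
store.put([1.0, 0.0], "Overweighted bullish sentiment; discount hype next time.")
store.put([0.0, 1.0], "Ignored liquidity risk in the safe analyst's argument.")
closest = store.query([0.9, 0.1], top_k=1)
```

In the real pipeline the embeddings come from a model and the store is durable, but the retrieval contract is the same: similar past situations surface the most relevant reflections.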
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "akshare==1.16.98",
# "backtrader==1.9.78.123",
# "boto3==1.39.9",
# "chainlit==2.5.5",
# "eodhd==1.0.32",
# "feedparser==6.0.11",
# "finnhub-python==2.4.23",
# "langchain-experimental==0.3.4",
# "langchain-openai==0.3.23",
# "pandas==2.3.0",
# "parsel==1.10.0",
# "praw==7.8.1",
# "pytz==2025.2",
# "questionary==2.1.0",
# "redis==6.2.0",
# "requests==2.32.4",
# "stockstats==0.6.5",
# "tqdm==4.67.1",
# "tushare==1.4.21",
# "typing-extensions==4.14.0",
# "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy
import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
create_neutral_debator,
create_risky_debator,
create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
reflect_bear_researcher,
reflect_bull_researcher,
reflect_research_manager,
reflect_risk_manager,
reflect_trader,
)
@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
"""Process a full trading signal to extract the core decision."""
messages = [
{
"role": "system",
"content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
},
{"role": "human", "content": full_signal},
]
return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content
async def run_analyst(analyst_name, state, online_tools):
# The caller passes a deepcopy of the state for isolation
run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")
# Run the analyst's chain
result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)
# Determine the report key
report_key = (
"sentiment_report"
if analyst_name == "social_media"
else f"{analyst_name}_report"
)
report_value = getattr(result_state, report_key)
return result_state.messages[1:], report_key, report_value
# {{docs-fragment main}}
@env.task
async def main(
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
if not selected_analysts:
raise ValueError(
"No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
)
state = AgentState(
messages=[{"role": "human", "content": company_name}],
company_of_interest=company_name,
trade_date=str(trade_date),
)
# Run all analysts concurrently
results = await asyncio.gather(
*[
run_analyst(analyst, deepcopy(state), online_tools)
for analyst in selected_analysts
]
)
# Flatten and append all resulting messages into the shared state
for messages, report_attr, report in results:
state.messages.extend(messages)
setattr(state, report_attr, report)
# Bull/Bear debate loop
state = await create_bull_researcher(QUICK_THINKING_LLM, state) # Start with bull
while state.investment_debate_state.count < 2 * max_debate_rounds:
current = state.investment_debate_state.current_response
if current.startswith("Bull"):
state = await create_bear_researcher(QUICK_THINKING_LLM, state)
else:
state = await create_bull_researcher(QUICK_THINKING_LLM, state)
state = await create_research_manager(DEEP_THINKING_LLM, state)
state = await create_trader(QUICK_THINKING_LLM, state)
# Risk debate loop
state = await create_risky_debator(QUICK_THINKING_LLM, state) # Start with risky
while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
speaker = state.risk_debate_state.latest_speaker
if speaker == "Risky":
state = await create_safe_debator(QUICK_THINKING_LLM, state)
elif speaker == "Safe":
state = await create_neutral_debator(QUICK_THINKING_LLM, state)
else:
state = await create_risky_debator(QUICK_THINKING_LLM, state)
state = await create_risk_manager(DEEP_THINKING_LLM, state)
decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)
return decision, state
# {{/docs-fragment main}}
# {{docs-fragment reflect_on_decisions}}
@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
await asyncio.gather(
reflect_bear_researcher(state, returns),
reflect_bull_researcher(state, returns),
reflect_trader(state, returns),
reflect_risk_manager(state, returns),
reflect_research_manager(state, returns),
)
return "Reflection completed."
# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
returns: str,
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> str:
_, state = await main(
selected_analysts,
max_debate_rounds,
max_risk_discuss_rounds,
online_tools,
company_name,
trade_date,
)
return await reflect_and_store(state, returns)
# {{/docs-fragment reflect_on_decisions}}
# {{docs-fragment execute_main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
# print(run.url)
# {{/docs-fragment execute_main}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*
### Running the simulation
First, set up your OpenAI secret (from [openai.com](https://platform.openai.com/api-keys)) and Finnhub API key (from [finnhub.io](https://finnhub.io/)):
```
flyte create secret openai_api_key
flyte create secret finnhub_api_key
```
Then [clone the repo](https://github.com/unionai/unionai-examples), navigate to the `v2/tutorials/trading_agents` directory, and run the following commands:
```
flyte create config --endpoint --project --domain --builder remote
uv run main.py
```
If you'd like to run the `reflect_on_decisions` task instead, comment out the `main` function call and uncomment the `reflect_on_decisions` call in the `__main__` block:
```
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    print(run.url)
    run.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*
Then run:
```
uv run main.py
```
## Why Flyte? _(A quick note before you go)_
You might now be wondering: can't I just build all this with Python and LangChain?
Absolutely. But as your project grows, you'll likely run into these challenges:
1. **Observability**: Agent workflows can feel opaque. You send a prompt, get a response, but what happened in between?
- Were the right tools used?
- Were correct arguments passed?
- How did the LLM reason through intermediate steps?
- Why did it fail?
Flyte gives you a window into each of these stages.
2. **Multi-agent coordination**: Real-world applications often require multiple agents with distinct roles and responsibilities. In such cases, you'll need:
- Isolated state per agent,
- Shared context where needed,
- And coordination, sequential or parallel.
Managing this manually gets fragile, fast. Flyte handles it for you.
3. **Scalability**: Agents and tools might need to run in isolated or containerized environments. Whether you're scaling out to more agents or more powerful hardware, Flyte lets you scale without taxing your local machine or racking up unnecessary cloud bills.
4. **Durability & recovery**: LLM-based workflows are often long-running and expensive. If something fails halfway:
- Do you lose all progress?
- Replay everything from scratch?
With Flyte, you get built-in caching, checkpointing, and recovery, so you can resume where you left off.
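The recovery behavior can be pictured as a checkpoint map: every completed step records its output, and a rerun resumes past anything already recorded. A stdlib sketch of the idea (Flyte provides this through caching and run state, not this mechanism):

```python
checkpoints: dict[str, str] = {}

def run_step(name: str, fn) -> str:
    """Run a step, or replay its recorded output if it already completed."""
    if name in checkpoints:
        return checkpoints[name]
    out = fn()
    checkpoints[name] = out
    return out

executed: list[str] = []

def pipeline():
    # A rerun after a crash walks the same steps but only executes missing ones.
    for name in ("analyze", "debate", "trade"):
        run_step(name, lambda n=name: (executed.append(n), f"{n} done")[1])

pipeline()  # first run executes every step
pipeline()  # second run replays from checkpoints, executing nothing new
```

Each step runs exactly once across both invocations, which is the property that makes long, expensive LLM workflows safe to resume.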
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/code-agent ===
# Run LLM-generated code
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/code_runner).
This example demonstrates how to run code generated by a large language model (LLM) using a `ContainerTask`.
The agent takes a user's question, generates Flyte 2 code using the Flyte 2 documentation as context, and runs it in an isolated container.
If the execution fails, the agent reflects on the error and retries
up to a configurable limit until it succeeds.
Using `ContainerTask` ensures that all generated code runs in a secure environment.
This gives you full flexibility to execute arbitrary logic safely and reliably.
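The `ContainerTask` defined below shells out to `uv run` and writes the combined output and the exit code to output files. The same capture pattern, minus the container isolation that makes it safe for untrusted code, can be sketched with `subprocess` (running plain `python` here purely for illustration):

```python
import subprocess
import sys
import tempfile

def run_script(source: str) -> tuple[str, str]:
    """Execute a generated script and capture (output, exit_code) as strings."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(source)
        path = f.name
    proc = subprocess.run(
        [sys.executable, path],
        capture_output=True,
        text=True,
    )
    # Mirror the ContainerTask outputs: combined stdout/stderr plus exit code.
    return proc.stdout + proc.stderr, str(proc.returncode)

out, code = run_script("print('hello from generated code')")
```

Returning the exit code as data, rather than raising on failure, is what lets the agent inspect the error and decide to retry.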
## What this example demonstrates
- How to combine LLM generation with programmatic execution.
- How to run untrusted or dynamically generated code securely.
- How to iteratively improve code using agent-like behavior.
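Stripped of the LLM and container pieces, the iterative-improvement behavior is a plain generate, check, reflect cycle; the stub functions below are illustrative stand-ins for the real tasks:

```python
def run_agent_loop(generate, check, reflect, max_iterations: int = 3):
    """Repeat generate/check/reflect until the check passes or attempts run out."""
    state = {"error": None, "iterations": 0}
    while True:
        state = generate(state)
        state = check(state)
        state["iterations"] += 1
        if state["error"] is None or state["iterations"] >= max_iterations:
            return state
        state = reflect(state)  # fold the failure back into the next attempt

attempts = []

def generate(state):
    # Stand-in for LLM code generation: succeeds on the second attempt.
    attempts.append(len(attempts) + 1)
    state["code"] = "good" if len(attempts) >= 2 else "bad"
    return state

def check(state):
    # Stand-in for running the code in a container and reading the exit code.
    state["error"] = None if state["code"] == "good" else "failed"
    return state

def reflect(state):
    # Stand-in for asking the LLM to analyze the error before retrying.
    state["notes"] = "fix the failure seen above"
    return state

final = run_agent_loop(generate, check, reflect)
```

The `max_iterations` bound is the "configurable limit" mentioned above: it guarantees the loop terminates even if the code never passes.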
## Setting up the agent environment
Let's start by importing the necessary libraries and setting up two environments: one for the container task and another for the agent task.
This example follows the `uv` script format to declare dependencies.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File
code_runner_task = ContainerTask(
name="run_flyte_v2",
image=flyte.Image.from_debian_base(),
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"script": File},
outputs={"result": str, "exit_code": str},
command=[
"/bin/bash",
"-c",
(
"set -o pipefail && "
"uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
"echo $? > /var/outputs/exit_code"
),
],
resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
# {{docs-fragment env}}
import tempfile
from typing import Optional
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field
container_env = flyte.TaskEnvironment.from_task(
"code-runner-container", code_runner_task
)
env = flyte.TaskEnvironment(
name="code_runner",
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
resources=flyte.Resources(cpu=1),
depends_on=[container_env],
)
# {{/docs-fragment env}}
# {{docs-fragment code_base_model}}
class Code(BaseModel):
"""Schema for code solutions to questions about Flyte v2."""
prefix: str = Field(
default="", description="Description of the problem and approach"
)
imports: str = Field(
default="", description="Code block with just import statements"
)
code: str = Field(
default="", description="Code block not including import statements"
)
# {{/docs-fragment code_base_model}}
# {{docs-fragment agent_state}}
class AgentState(BaseModel):
messages: list[dict[str, str]] = Field(default_factory=list)
generation: Code = Field(default_factory=Code)
iterations: int = 0
error: str = "no"
output: Optional[str] = None
# {{/docs-fragment agent_state}}
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
# Grader prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.
Use the following pattern to execute the code:
if __name__ == "__main__":
flyte.init_from_config()
print(flyte.run(...))
Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
),
("placeholder", "{messages}"),
]
)
expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
from bs4 import BeautifulSoup
from langchain_community.document_loaders.recursive_url_loader import (
RecursiveUrlLoader,
)
loader = RecursiveUrlLoader(
url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
)
docs = loader.load()
# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
[doc.page_content for doc in d_reversed]
)
return concatenated_content
# {{/docs-fragment docs_retriever}}
# {{docs-fragment generate}}
@env.task
async def generate(
question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Generate a code solution
Args:
question (str): The user question
state (dict): The current graph state
concatenated_content (str): The concatenated docs content
debug (bool): Debug mode
Returns:
state (dict): New key added to state, generation
"""
print("---GENERATING CODE SOLUTION---")
messages = state.messages
iterations = state.iterations
error = state.error
# We have been routed back to generation with an error
if error == "yes":
messages += [
{
"role": "user",
"content": (
"Now, try again. Invoke the code tool to structure the output "
"with a prefix, imports, and code block:"
),
}
]
code_gen_chain = await generate_code_gen_chain(debug)
# Solution
code_solution = code_gen_chain.invoke(
{
"context": concatenated_content,
"messages": (
messages if messages else [{"role": "user", "content": question}]
),
}
)
messages += [
{
"role": "assistant",
"content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
}
]
return AgentState(
messages=messages,
generation=code_solution,
iterations=iterations + 1,
error=error,
output=state.output,
)
# {{/docs-fragment generate}}
# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
"""
Check code
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, error
"""
print("---CHECKING CODE---")
# State
messages = state.messages
code_solution = state.generation
iterations = state.iterations
# Get solution components
imports = code_solution.imports.strip()
code = code_solution.code.strip()
# Create temp file for imports
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False
) as imports_file:
imports_file.write(imports + "\n")
imports_path = imports_file.name
# Create temp file for code body
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
code_file.write(imports + "\n" + code + "\n")
code_path = code_file.name
# Check imports
import_output, import_exit_code = await code_runner_task(
script=await File.from_local(imports_path)
)
if import_exit_code.strip() != "0":
print("---CODE IMPORT CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the import test: {import_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=import_output,
)
else:
print("---CODE IMPORT CHECK: PASSED---")
# Check execution
code_output, code_exit_code = await code_runner_task(
script=await File.from_local(code_path)
)
if code_exit_code.strip() != "0":
print("---CODE BLOCK CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the code execution test: {code_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=code_output,
)
else:
print("---CODE BLOCK CHECK: PASSED---")
# No errors
print("---NO CODE TEST FAILURES---")
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="no",
output=code_output,
)
# {{/docs-fragment code_check}}
# {{docs-fragment reflect}}
@env.task
async def reflect(
state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Reflect on errors
Args:
state (dict): The current graph state
concatenated_content (str): Concatenated docs content
debug (bool): Debug mode
Returns:
state (dict): New key added to state, reflection
"""
print("---REFLECTING---")
# State
messages = state.messages
iterations = state.iterations
code_solution = state.generation
# Prompt reflection
code_gen_chain = await generate_code_gen_chain(debug)
# Add reflection
reflections = code_gen_chain.invoke(
{"context": concatenated_content, "messages": messages}
)
messages += [
{
"role": "assistant",
"content": f"Here are reflections on the error: {reflections}",
}
]
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error=state.error,
output=state.output,
)
# {{/docs-fragment reflect}}
# {{docs-fragment main}}
@env.task
async def main(
question: str = (
"Define a two-task pattern where the second catches OOM from the first and retries with more memory."
),
url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
max_iterations: int = 3,
debug: bool = False,
) -> str:
concatenated_content = await docs_retriever(url=url)
state: AgentState = AgentState()
iterations = 0
while True:
with flyte.group(f"code-generation-pass-{iterations + 1}"):
state = await generate(question, state, concatenated_content, debug)
state = await code_check(state)
error = state.error
iterations = state.iterations
if error == "no" or iterations >= max_iterations:
print("---DECISION: FINISH---")
code_solution = state.generation
prefix = code_solution.prefix
imports = code_solution.imports
code = code_solution.code
code_output = state.output
return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
else:
print("---DECISION: RE-TRY SOLUTION---")
state = await reflect(state, concatenated_content, debug)
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
> [!NOTE]
> You can set up access to the OpenAI API using a Flyte secret.
>
> ```
> flyte create secret openai_api_key
> ```
We store the LLM-generated code in a structured format. This allows us to:
- Enforce consistent formatting
- Make debugging easier
- Log and analyze generations systematically
By capturing metadata alongside the raw code, we maintain transparency and make it easier to iterate or trace issues over time.
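The pydantic `Code` model in the script captures exactly these three fields: a description, the imports, and the code body. As a rough stdlib analogue (using `dataclasses` instead of pydantic, purely for illustration):

```python
from dataclasses import dataclass, asdict

@dataclass
class CodeSolution:
    """Structured container for one LLM code generation."""
    prefix: str = ""   # description of the problem and approach
    imports: str = ""  # import statements only
    code: str = ""     # the code body, without imports

sol = CodeSolution(
    prefix="Print a greeting",
    imports="import sys",
    code="sys.stdout.write('hi')",
)
record = asdict(sol)  # plain dict, ready to log or analyze
```

Keeping imports separate from the body is what lets the checker below validate them independently before executing the full script.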
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File
code_runner_task = ContainerTask(
name="run_flyte_v2",
image=flyte.Image.from_debian_base(),
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"script": File},
outputs={"result": str, "exit_code": str},
command=[
"/bin/bash",
"-c",
(
"set -o pipefail && "
"uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
"echo $? > /var/outputs/exit_code"
),
],
resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
# {{docs-fragment env}}
import tempfile
from typing import Optional
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field
container_env = flyte.TaskEnvironment.from_task(
"code-runner-container", code_runner_task
)
env = flyte.TaskEnvironment(
name="code_runner",
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
resources=flyte.Resources(cpu=1),
depends_on=[container_env],
)
# {{/docs-fragment env}}
# {{docs-fragment code_base_model}}
class Code(BaseModel):
"""Schema for code solutions to questions about Flyte v2."""
prefix: str = Field(
default="", description="Description of the problem and approach"
)
imports: str = Field(
default="", description="Code block with just import statements"
)
code: str = Field(
default="", description="Code block not including import statements"
)
# {{/docs-fragment code_base_model}}
# {{docs-fragment agent_state}}
class AgentState(BaseModel):
messages: list[dict[str, str]] = Field(default_factory=list)
generation: Code = Field(default_factory=Code)
iterations: int = 0
error: str = "no"
output: Optional[str] = None
# {{/docs-fragment agent_state}}
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
# Grader prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.
Use the following pattern to execute the code:
if __name__ == "__main__":
flyte.init_from_config()
print(flyte.run(...))
Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
),
("placeholder", "{messages}"),
]
)
expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
from bs4 import BeautifulSoup
from langchain_community.document_loaders.recursive_url_loader import (
RecursiveUrlLoader,
)
loader = RecursiveUrlLoader(
url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
)
docs = loader.load()
# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
[doc.page_content for doc in d_reversed]
)
return concatenated_content
# {{/docs-fragment docs_retriever}}
# {{docs-fragment generate}}
@env.task
async def generate(
question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Generate a code solution
Args:
question (str): The user question
state (dict): The current graph state
concatenated_content (str): The concatenated docs content
debug (bool): Debug mode
Returns:
state (dict): New key added to state, generation
"""
print("---GENERATING CODE SOLUTION---")
    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": (
                messages if messages else [{"role": "user", "content": question}]
            ),
        }
    )
    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )
# {{/docs-fragment generate}}


# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )
    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )
    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )
# {{/docs-fragment code_check}}


# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (dict): The current graph state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, reflection
    """
    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )
# {{/docs-fragment reflect}}


# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)
    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation
                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code
                code_output = state.output

                return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
We then define a state model to persist the agent's history across iterations. This includes previous messages,
generated code, and any errors encountered.
Maintaining this history allows the agent to reflect on past attempts, avoid repeating mistakes,
and iteratively improve the generated code.
```
# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )
# {{/docs-fragment code_base_model}}


# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None
# {{/docs-fragment agent_state}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
## Retrieve docs
We define a task that recursively loads documents from a given URL and concatenates them into a single string, which is then included in the LLM prompt.
We set `max_depth = 20` to avoid loading an excessive number of documents.
Even with this limit, the resulting context can still be quite large.
To handle this, we use a model that supports an extended context window (GPT-4o in this case).
> [!NOTE]
> Appending all documents into a single string can result in extremely large contexts, potentially exceeding the LLM's token limit.
> If your dataset grows beyond what a single prompt can handle, there are a couple of strategies you can use.
> One option is to apply Retrieval-Augmented Generation (RAG), where you chunk the documents, embed them using a model,
> store the vectors in a vector database, and retrieve only the most relevant pieces at inference time.
>
> An alternative approach is to pass references to full files into the prompt, allowing the LLM to decide which files are most relevant based
> on natural-language search over file paths, summaries, or even contents. This method assumes that only a subset of files
> will be necessary for a given task, and the LLM is responsible for navigating the structure and identifying what to read.
> While this can be a lighter-weight solution for smaller datasets, its effectiveness depends on how well the LLM can
> reason over file references and the reliability of its internal search heuristics.
```
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))
    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content
# {{/docs-fragment docs_retriever}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
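The RAG alternative described in the note above can be sketched without any external services. The helpers below (`chunk_text`, `embed`, `cosine`, `retrieve`) are hypothetical names introduced for illustration, and the bag-of-words `embed` is a stand-in for a real embedding model; in practice you would embed chunks with a model and store them in a vector database:

```python
import math
from collections import Counter


def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> list[str]:
    """Split text into overlapping character chunks."""
    step = chunk_size - overlap
    return [text[i : i + chunk_size] for i in range(0, max(len(text), 1), step)]


def embed(text: str) -> Counter:
    """Toy stand-in for an embedding model: a term-frequency vector."""
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-frequency vectors."""
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def retrieve(docs: str, question: str, k: int = 2) -> str:
    """Keep only the k chunks most similar to the question."""
    chunks = chunk_text(docs)
    q = embed(question)
    ranked = sorted(chunks, key=lambda c: cosine(embed(c), q), reverse=True)
    return "\n\n---\n\n".join(ranked[:k])
```

With this approach, `docs_retriever` would return `retrieve(concatenated_content, question)` instead of the full concatenation, keeping the prompt within the model's context window.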
## Code generation
Next, we define a utility function that constructs the LLM chain responsible for generating Python code from user input. The chain combines a LangChain `ChatPromptTemplate` to structure the input with an OpenAI chat model that generates well-formed, Flyte 2-compatible Python scripts.
```
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.
Use the following pattern to execute the code:

if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))

Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)
    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
We then define a `generate` task responsible for producing the code solution.
To improve clarity and testability, the output is structured in three parts:
a short summary of the generated solution, a list of necessary imports,
and the main body of executable code.
```
# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (dict): The current graph state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, generation
    """
    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": (
                messages if messages else [{"role": "user", "content": question}]
            ),
        }
    )
    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]
    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )
# {{/docs-fragment generate}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
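On success, `main` stitches the three structured parts back into a single answer string with a plain f-string. A minimal sketch of that assembly, using hypothetical field values in place of the LLM chain's real output:

```python
# Hypothetical structured-output fields; in the tutorial these come from
# the LLM chain and the container run.
prefix = "Computes a square root inside a task."
imports = "import math"
code = "print(math.sqrt(16))"
code_output = "4.0"

# Same formatting main() uses for its return value: description first,
# then imports, then code, then the captured execution result.
answer = f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
print(answer)
```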
A `ContainerTask` then executes this code in an isolated container environment.
It takes the code as input, runs it safely, and returns the program's output and exit code.
```
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
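The exit-code capture in the container command can be illustrated outside Flyte with a plain subprocess call. This is a local sketch of the same shape (combined stdout/stderr plus a string exit code), not the container path the task actually uses:

```python
import subprocess
import sys
import tempfile

# Write a deliberately failing script, mirroring what code_check uploads.
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
    f.write("raise RuntimeError('boom')\n")
    script_path = f.name

# Same shape as the ContainerTask command: collect combined output and the
# exit code as a string, instead of letting the failure raise immediately.
proc = subprocess.run([sys.executable, script_path], capture_output=True, text=True)
result = proc.stdout + proc.stderr
exit_code = str(proc.returncode)

print(exit_code)  # non-zero, just like the value written to /var/outputs/exit_code
```

Because the exit code is returned as data rather than raised as an error, the agent can inspect it and decide whether to retry.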
This task verifies that the generated code runs as expected.
It tests the import statements first, then executes the full code.
It records the output and any error messages in the agent state for further analysis.
```
# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )
    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )
    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )
# {{/docs-fragment code_check}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
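The import-first strategy can be sketched locally. This simplification compiles and runs the import block alone before executing the full script in-process with `exec`, rather than shipping each file to a container as `code_check` does:

```python
# Two-stage check mirroring code_check: first the imports alone,
# then the imports plus the code body.
def stage_errors(imports: str, code: str) -> str:
    try:
        exec(compile(imports, "<imports>", "exec"), {})
    except Exception:
        return "yes"  # import failure routes the agent back to generate()
    try:
        exec(compile(imports + "\n" + code, "<solution>", "exec"), {})
    except Exception:
        return "yes"  # execution failure also triggers a retry
    return "no"

good_imports = "import json\nimport math"
body = "print(json.dumps({'root': math.sqrt(9)}))"

print(stage_errors(good_imports, body))             # passes both stages
print(stage_errors("import no_such_module", body))  # fails at the import stage
```

Checking imports separately lets the agent distinguish a missing dependency from a logic error, which produces a more targeted error message for the next generation pass.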
If an error occurs, a separate task reflects on the failure and generates a response.
This reflection is added to the agent state to guide future iterations.
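The message accumulation that drives reflection can be sketched with plain lists. The contents below are hypothetical stand-ins; in the tutorial the error text comes from the container run and the reflection from the LLM chain:

```python
# Conversation state across one failed pass: code_check appends the error
# as a user message, then reflect() appends its analysis as an assistant message.
messages: list[dict[str, str]] = [
    {"role": "user", "content": "Write a task that sums a list."},
]
messages += [
    {"role": "user", "content": "Your solution failed the import test: ModuleNotFoundError"},
]
reflection = "The import block referenced a module not installed in the image."
messages += [
    {"role": "assistant", "content": f"Here are reflections on the error: {reflection}"},
]

print(len(messages))  # the next generate() pass sees all three messages
```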
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "langchain-core==0.3.66",
#     "langchain-openai==0.3.24",
#     "langchain-community==0.3.26",
#     "beautifulsoup4==4.13.4",
#     "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}

# {{docs-fragment env}}
import tempfile
from typing import Optional

from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

container_env = flyte.TaskEnvironment.from_task(
    "code-runner-container", code_runner_task
)

env = flyte.TaskEnvironment(
    name="code_runner",
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
    resources=flyte.Resources(cpu=1),
    depends_on=[container_env],
)
# {{/docs-fragment env}}

# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )
# {{/docs-fragment code_base_model}}

# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None
# {{/docs-fragment agent_state}}

# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Grader prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))

Your response will be shown to the user.

Here is a full set of documentation:
-------
{context}
-------

Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.

Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)
    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}

# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))
    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content
# {{/docs-fragment docs_retriever}}

# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (dict): The current graph state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, generation
    """
    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": (
                messages if messages else [{"role": "user", "content": question}]
            ),
        }
    )
    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )
# {{/docs-fragment generate}}

# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )
    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )
    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )
# {{/docs-fragment code_check}}

# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (dict): The current graph state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, reflection
    """
    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )
# {{/docs-fragment reflect}}

# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)
    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation
                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code
                code_output = state.output

                return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
Finally, we define a `main` task that runs the code agent and orchestrates the steps above.
If the code execution fails, we reflect on the error and retry until we reach the maximum number of iterations.
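Stripped of the Flyte machinery, the control flow of `main` can be sketched as a plain retry loop. This is an illustrative skeleton only: the stub functions below are hypothetical stand-ins for the `generate`, `code_check`, and `reflect` tasks shown above.

```python
# Minimal sketch of the generate -> check -> reflect retry loop, with the
# Flyte tasks replaced by plain stub functions. All names here are
# hypothetical stand-ins for the real tasks.

def run_agent_loop(generate, check, reflect, max_iterations=3):
    """Drive generation until the check passes or attempts run out."""
    state = {"iterations": 0, "error": "no", "solution": None}
    while True:
        state["solution"] = generate(state)
        state["error"] = check(state["solution"])
        state["iterations"] += 1
        if state["error"] == "no" or state["iterations"] >= max_iterations:
            return state
        # Check failed: record a reflection to guide the next attempt
        reflect(state)

# Toy stubs: the first attempt fails its check, the second passes.
attempts = []

def generate(state):
    return f"solution-{state['iterations'] + 1}"

def check(solution):
    attempts.append(solution)
    return "no" if len(attempts) > 1 else "yes"

final = run_agent_loop(generate, check, reflect=lambda s: None)
```

In the real workflow each pass of this loop is wrapped in a `flyte.group`, so every generation attempt shows up as its own grouped set of task runs in the UI.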
## Running the code agent
Run the agent with:
```
uv run agent.py
```
If things are working properly, you should see output similar to the following:
```
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE BLOCK CHECK: PASSED---
---NO CODE TEST FAILURES---
---DECISION: FINISH---
In this solution, we define two tasks using Flyte v2.
The first task, `oomer`, is designed to simulate an out-of-memory (OOM) error by attempting to allocate a large list.
The second task, `failure_recovery`, attempts to execute `oomer` and catches any OOM errors.
If an OOM error is caught, it retries the `oomer` task with increased memory resources.
This pattern demonstrates how to handle resource-related exceptions and dynamically adjust task configurations in Flyte workflows.

import asyncio
import flyte
import flyte.errors

env = flyte.TaskEnvironment(name="oom_example", resources=flyte.Resources(cpu=1, memory="250Mi"))

@env.task
async def oomer(x: int):
    large_list = [0] * 100000000  # Simulate OOM
    print(len(large_list))

@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42

...
```
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/text_to_sql ===
# Text-to-SQL
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/text_to_sql); based on work by [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/workflow/advanced_text_to_sql/).
Data analytics drives modern decision-making, but SQL often creates a bottleneck. Writing queries requires technical expertise, so non-technical stakeholders must rely on data teams. That translation layer slows everyone down.
Text-to-SQL narrows this gap by turning natural language into executable SQL queries. It lowers the barrier to structured data and makes databases accessible to more people.
In this tutorial, we build a Text-to-SQL workflow using LlamaIndex and evaluate it on the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (a benchmark of natural language questions over semi-structured tables). We then explore prompt optimization to see whether it improves accuracy and show how to track prompts and results over time. Along the way, we'll see what worked, what didn't, and what we learned about building durable evaluation pipelines. The pattern here can be adapted to your own datasets and workflows.

## Ingesting data
We start by ingesting the WikiTableQuestions dataset, which comes as CSV files, into a SQLite database. This database serves as the source of truth for our Text-to-SQL pipeline.
```
import asyncio
import fnmatch
import os
import re
import zipfile

import flyte
import pandas as pd
import requests
from flyte.io import Dir, File
from llama_index.core.llms import ChatMessage
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from utils import env


# {{docs-fragment table_info}}
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(..., description="table name (underscores only, no spaces)")
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )
# {{/docs-fragment table_info}}


@env.task
async def download_and_extract(zip_path: str, search_glob: str) -> Dir:
    """Download and extract the dataset zip file if not already available."""
    output_zip = "data.zip"
    extract_dir = "wiki_table_questions"

    if not os.path.exists(zip_path):
        response = requests.get(zip_path, stream=True)
        with open(output_zip, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        output_zip = zip_path
        print(f"Using existing file {output_zip}")

    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(output_zip, "r") as zip_ref:
        for member in zip_ref.namelist():
            if fnmatch.fnmatch(member, search_glob):
                zip_ref.extract(member, extract_dir)

    remote_dir = await Dir.from_local(extract_dir)
    return remote_dir


async def read_csv_file(
    csv_file: File, nrows: int | None = None
) -> pd.DataFrame | None:
    """Safely download and parse a CSV file into a DataFrame."""
    try:
        local_csv_file = await csv_file.download()
        return pd.read_csv(local_csv_file, nrows=nrows)
    except Exception as e:
        print(f"Error parsing {csv_file.path}: {e}")
        return None


def sanitize_column_name(col_name: str) -> str:
    """Sanitize column names by replacing spaces/special chars with underscores."""
    return re.sub(r"\W+", "_", col_name)


async def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    """Create a SQL table from a Pandas DataFrame."""
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Define table columns based on DataFrame dtypes
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]
    table = Table(table_name, metadata_obj, *columns)

    # Create table in database
    metadata_obj.create_all(engine)

    # Insert data into table
    with engine.begin() as conn:
        for _, row in df.iterrows():
            conn.execute(table.insert().values(**row.to_dict()))


@flyte.trace
async def create_table(
    csv_file: File, table_info: TableInfo, database_path: str
) -> str:
    """Safely create a table from CSV if parsing succeeds."""
    df = await read_csv_file(csv_file)
    if df is None:
        return "false"

    print(f"Creating table: {table_info.table_name}")
    engine = create_engine(f"sqlite:///{database_path}")
    metadata_obj = MetaData()
    await create_table_from_dataframe(df, table_info.table_name, engine, metadata_obj)
    return "true"


@flyte.trace
async def llm_structured_predict(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    feedback: str,
    llm: OpenAI,
) -> TableInfo:
    return llm.structured_predict(
        TableInfo,
        prompt_tmpl,
        feedback=feedback,
        table_str=df_str,
        exclude_table_name_list=str(list(table_names)),
    )


async def generate_unique_table_info(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    llm: OpenAI,
    tablename_lock: asyncio.Lock,
    retries: int = 3,
) -> TableInfo | None:
    """Process a single CSV file to generate a unique TableInfo."""
    last_table_name = None
    for attempt in range(retries):
        feedback = ""
        if attempt > 0:
            feedback = f"Note: '{last_table_name}' already exists. Please pick a new name not in {table_names}."

        table_info = await llm_structured_predict(
            df_str, table_names, prompt_tmpl, feedback, llm
        )
        last_table_name = table_info.table_name

        async with tablename_lock:
            if table_info.table_name not in table_names:
                table_names.append(table_info.table_name)
                return table_info

        print(f"Table name {table_info.table_name} already exists, retrying...")
    return None


async def process_csv_file(
    csv_file: File,
    table_names: list[str],
    semaphore: asyncio.Semaphore,
    tablename_lock: asyncio.Lock,
    llm: OpenAI,
    prompt_tmpl: ChatPromptTemplate,
) -> TableInfo | None:
    """Process a single CSV file to generate a unique TableInfo."""
    async with semaphore:
        df = await read_csv_file(csv_file, nrows=10)
        if df is None:
            return None
        return await generate_unique_table_info(
            df.to_csv(), table_names, prompt_tmpl, llm, tablename_lock
        )


@env.task
async def extract_table_info(
    data_dir: Dir, model: str, concurrency: int
) -> list[TableInfo | None]:
    """Extract structured table information from CSV files."""
    table_names: list[str] = []
    semaphore = asyncio.Semaphore(concurrency)
    tablename_lock = asyncio.Lock()
    llm = OpenAI(model=model)

    prompt_str = """\
Provide a JSON object with the following fields:
- `table_name`: must be unique and descriptive (underscores only, no generic names).
- `table_summary`: short and concise summary of the table.

Do NOT use any of these table names: {exclude_table_name_list}

Table:
{table_str}

{feedback}
"""
    prompt_tmpl = ChatPromptTemplate(
        message_templates=[ChatMessage.from_str(prompt_str, role="user")]
    )

    tasks = [
        process_csv_file(
            csv_file, table_names, semaphore, tablename_lock, llm, prompt_tmpl
        )
        async for csv_file in data_dir.walk()
    ]
    return await asyncio.gather(*tasks)


# {{docs-fragment data_ingestion}}
@env.task
async def data_ingestion(
    csv_zip_path: str = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
    search_glob: str = "WikiTableQuestions/csv/200-csv/*.csv",
    concurrency: int = 5,
    model: str = "gpt-4o-mini",
) -> tuple[File, list[TableInfo | None]]:
    """Main data ingestion pipeline: download → extract → analyze → create DB."""
    data_dir = await download_and_extract(csv_zip_path, search_glob)
    table_infos = await extract_table_info(data_dir, model, concurrency)

    database_path = "wiki_table_questions.db"
    i = 0
    async for csv_file in data_dir.walk():
        table_info = table_infos[i]
        if table_info:
            ok = await create_table(csv_file, table_info, database_path)
            if ok == "false":
                table_infos[i] = None
        else:
            print(f"Skipping table creation for {csv_file} due to missing TableInfo.")
        i += 1

    db_file = await File.from_local(database_path)
    return db_file, table_infos
# {{/docs-fragment data_ingestion}}
CODE0
import asyncio
import fnmatch
import os
import re
import zipfile
import flyte
import pandas as pd
import requests
from flyte.io import Dir, File
from llama_index.core.llms import ChatMessage
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from utils import env
# {{docs-fragment table_info}}
class TableInfo(BaseModel):
"""Information regarding a structured table."""
table_name: str = Field(..., description="table name (underscores only, no spaces)")
table_summary: str = Field(
..., description="short, concise summary/caption of the table"
)
# {{/docs-fragment table_info}}
@env.task
async def download_and_extract(zip_path: str, search_glob: str) -> Dir:
"""Download and extract the dataset zip file if not already available."""
output_zip = "data.zip"
extract_dir = "wiki_table_questions"
if not os.path.exists(zip_path):
response = requests.get(zip_path, stream=True)
with open(output_zip, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
else:
output_zip = zip_path
print(f"Using existing file {output_zip}")
os.makedirs(extract_dir, exist_ok=True)
with zipfile.ZipFile(output_zip, "r") as zip_ref:
for member in zip_ref.namelist():
if fnmatch.fnmatch(member, search_glob):
zip_ref.extract(member, extract_dir)
remote_dir = await Dir.from_local(extract_dir)
return remote_dir
async def read_csv_file(
csv_file: File, nrows: int | None = None
) -> pd.DataFrame | None:
"""Safely download and parse a CSV file into a DataFrame."""
try:
local_csv_file = await csv_file.download()
return pd.read_csv(local_csv_file, nrows=nrows)
except Exception as e:
print(f"Error parsing {csv_file.path}: {e}")
return None
def sanitize_column_name(col_name: str) -> str:
"""Sanitize column names by replacing spaces/special chars with underscores."""
return re.sub(r"\W+", "_", col_name)
async def create_table_from_dataframe(
df: pd.DataFrame, table_name: str, engine, metadata_obj
):
"""Create a SQL table from a Pandas DataFrame."""
# Sanitize column names
sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
df = df.rename(columns=sanitized_columns)
# Define table columns based on DataFrame dtypes
columns = [
Column(col, String if dtype == "object" else Integer)
for col, dtype in zip(df.columns, df.dtypes)
]
table = Table(table_name, metadata_obj, *columns)
# Create table in database
metadata_obj.create_all(engine)
# Insert data into table
with engine.begin() as conn:
for _, row in df.iterrows():
conn.execute(table.insert().values(**row.to_dict()))
@flyte.trace
async def create_table(
csv_file: File, table_info: TableInfo, database_path: str
) -> str:
"""Safely create a table from CSV if parsing succeeds."""
df = await read_csv_file(csv_file)
if df is None:
return "false"
print(f"Creating table: {table_info.table_name}")
engine = create_engine(f"sqlite:///{database_path}")
metadata_obj = MetaData()
await create_table_from_dataframe(df, table_info.table_name, engine, metadata_obj)
return "true"
@flyte.trace
async def llm_structured_predict(
df_str: str,
table_names: list[str],
prompt_tmpl: ChatPromptTemplate,
feedback: str,
llm: OpenAI,
) -> TableInfo:
return llm.structured_predict(
TableInfo,
prompt_tmpl,
feedback=feedback,
table_str=df_str,
exclude_table_name_list=str(list(table_names)),
)
async def generate_unique_table_info(
df_str: str,
table_names: list[str],
prompt_tmpl: ChatPromptTemplate,
llm: OpenAI,
tablename_lock: asyncio.Lock,
retries: int = 3,
) -> TableInfo | None:
"""Process a single CSV file to generate a unique TableInfo."""
last_table_name = None
for attempt in range(retries):
feedback = ""
if attempt > 0:
feedback = f"Note: '{last_table_name}' already exists. Please pick a new name not in {table_names}."
table_info = await llm_structured_predict(
df_str, table_names, prompt_tmpl, feedback, llm
)
last_table_name = table_info.table_name
async with tablename_lock:
if table_info.table_name not in table_names:
table_names.append(table_info.table_name)
return table_info
print(f"Table name {table_info.table_name} already exists, retrying...")
return None
async def process_csv_file(
csv_file: File,
table_names: list[str],
semaphore: asyncio.Semaphore,
tablename_lock: asyncio.Lock,
llm: OpenAI,
prompt_tmpl: ChatPromptTemplate,
) -> TableInfo | None:
"""Process a single CSV file to generate a unique TableInfo."""
async with semaphore:
df = await read_csv_file(csv_file, nrows=10)
if df is None:
return None
return await generate_unique_table_info(
df.to_csv(), table_names, prompt_tmpl, llm, tablename_lock
)
@env.task
async def extract_table_info(
data_dir: Dir, model: str, concurrency: int
) -> list[TableInfo | None]:
"""Extract structured table information from CSV files."""
table_names: list[str] = []
semaphore = asyncio.Semaphore(concurrency)
tablename_lock = asyncio.Lock()
llm = OpenAI(model=model)
prompt_str = """\
Provide a JSON object with the following fields:
- `table_name`: must be unique and descriptive (underscores only, no generic names).
- `table_summary`: short and concise summary of the table.
Do NOT use any of these table names: {exclude_table_name_list}
Table:
{table_str}
{feedback}
"""
prompt_tmpl = ChatPromptTemplate(
message_templates=[ChatMessage.from_str(prompt_str, role="user")]
)
tasks = [
process_csv_file(
csv_file, table_names, semaphore, tablename_lock, llm, prompt_tmpl
)
async for csv_file in data_dir.walk()
]
return await asyncio.gather(*tasks)
# {{docs-fragment data_ingestion}}
@env.task
async def data_ingestion(
csv_zip_path: str = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob: str = "WikiTableQuestions/csv/200-csv/*.csv",
concurrency: int = 5,
model: str = "gpt-4o-mini",
) -> tuple[File, list[TableInfo | None]]:
"""Main data ingestion pipeline: download β extract β analyze β create DB."""
data_dir = await download_and_extract(csv_zip_path, search_glob)
table_infos = await extract_table_info(data_dir, model, concurrency)
database_path = "wiki_table_questions.db"
i = 0
async for csv_file in data_dir.walk():
table_info = table_infos[i]
if table_info:
ok = await create_table(csv_file, table_info, database_path)
if ok == "false":
table_infos[i] = None
else:
print(f"Skipping table creation for {csv_file} due to missing TableInfo.")
i += 1
db_file = await File.from_local(database_path)
return db_file, table_infos
# {{/docs-fragment data_ingestion}}
```

```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "sqlalchemy>=2.0.0",
# "pandas>=2.0.0",
# "requests>=2.25.0",
# "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///
import asyncio
from pathlib import Path
import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
PromptTemplate,
SQLDatabase,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env
# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
"""Index a single table into vector store."""
path = f"{table_index_dir}/{table_name}"
engine = create_engine(database_uri)
def _fetch_rows():
with engine.connect() as conn:
cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
return cursor.fetchall()
result = await asyncio.to_thread(_fetch_rows)
nodes = [TextNode(text=str(tuple(row))) for row in result]
index = VectorStoreIndex(nodes)
index.set_index_id("vector_index")
index.storage_context.persist(path)
return path
@env.task
async def index_all_tables(db_file: File) -> Dir:
"""Index all tables concurrently."""
table_index_dir = "table_indices"
Path(table_index_dir).mkdir(exist_ok=True)
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
tasks = [
index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
for t in sql_database.get_usable_table_names()
]
await asyncio.gather(*tasks)
remote_dir = await Dir.from_local(table_index_dir)
return remote_dir
# {{/docs-fragment index_tables}}
@flyte.trace
async def get_table_schema_context(
table_schema_obj: SQLTableSchema,
database_uri: str,
) -> str:
"""Retrieve schema + optional description context for a single table."""
engine = create_engine(database_uri)
sql_database = SQLDatabase(engine)
table_info = sql_database.get_single_table_info(table_schema_obj.table_name)
if table_schema_obj.context_str:
table_info += f" The table description is: {table_schema_obj.context_str}"
return table_info
@flyte.trace
async def get_table_row_context(
table_schema_obj: SQLTableSchema,
local_vector_index_dir: str,
query: str,
) -> str:
"""Retrieve row-level context examples using vector search."""
storage_context = StorageContext.from_defaults(
persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
)
vector_index = load_index_from_storage(storage_context, index_id="vector_index")
vector_retriever = vector_index.as_retriever(similarity_top_k=2)
relevant_nodes = vector_retriever.retrieve(query)
if not relevant_nodes:
return ""
row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
for node in relevant_nodes:
row_context += str(node.get_content()) + "\n"
return row_context
async def process_table(
table_schema_obj: SQLTableSchema,
database_uri: str,
local_vector_index_dir: str,
query: str,
) -> str:
"""Combine schema + row context for one table."""
table_info = await get_table_schema_context(table_schema_obj, database_uri)
row_context = await get_table_row_context(
table_schema_obj, local_vector_index_dir, query
)
full_context = table_info
if row_context:
full_context += "\n" + row_context
print(f"Table Info: {full_context}")
return full_context
async def get_table_context_and_rows_str(
query: str,
database_uri: str,
table_schema_objs: list[SQLTableSchema],
vector_index_dir: Dir,
):
"""Get combined schema + row context for all tables."""
local_vector_index_dir = await vector_index_dir.download()
# run per-table work concurrently
context_strs = await asyncio.gather(
*[
process_table(t, database_uri, local_vector_index_dir, query)
for t in table_schema_objs
]
)
return "\n\n".join(context_strs)
# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
query: str,
table_infos: list[TableInfo | None],
db_file: File,
vector_index_dir: Dir,
) -> str:
"""Retrieve relevant tables and return schema context string."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
for t in table_infos
if t is not None
]
obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)
retrieved_schemas = obj_retriever.retrieve(query)
return await get_table_context_and_rows_str(
query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
)
# {{/docs-fragment retrieve_tables}}
def parse_response_to_sql(chat_response: ChatResponse) -> str:
"""Extract SQL query from LLM response."""
response = chat_response.message.content
sql_query_start = response.find("SQLQuery:")
if sql_query_start != -1:
response = response[sql_query_start:]
if response.startswith("SQLQuery:"):
response = response[len("SQLQuery:") :]
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
return response.strip().strip(" CODE2
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*
The main `text_to_sql` task orchestrates the pipeline:
- Ingest data
- Build vector indices for each table
- Retrieve relevant tables and rows
- Generate SQL queries with an LLM
- Execute queries and synthesize answers
We use OpenAI GPT models with carefully structured prompts to maximize SQL correctness.
### Vector indexing
We index each table's rows semantically so the model can retrieve relevant examples during SQL generation.
CODE3 ").strip()
# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
"""Generate SQL query from natural language question and table context."""
llm = OpenAI(model=model)
fmt_messages = (
PromptTemplate(
prompt,
prompt_type=PromptType.TEXT_TO_SQL,
)
.partial_format(dialect="sqlite")
.format_messages(query_str=query, schema=table_context)
)
chat_response = await llm.achat(fmt_messages)
return parse_response_to_sql(chat_response)
@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
"""Run SQL query on database and synthesize final response."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
response_synthesis_prompt = PromptTemplate(
"Given an input question, synthesize a response from the query results.\n"
"Query: {query_str}\n"
"SQL: {sql_query}\n"
"SQL Response: {context_str}\n"
"Response: "
)
llm = OpenAI(model=model)
fmt_messages = response_synthesis_prompt.format_messages(
sql_query=sql,
context_str=str(retrieved_rows),
query_str=query,
)
chat_response = await llm.achat(fmt_messages)
return chat_response.message.content
# {{/docs-fragment sql_and_response}}
# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
system_prompt: str = (
"Given an input question, first create a syntactically correct {dialect} "
"query to run, then look at the results of the query and return the answer. "
"You can order the results by a relevant column to return the most "
"interesting examples in the database.\n\n"
"Never query for all the columns from a specific table, only ask for a "
"few relevant columns given the question.\n\n"
"Pay attention to use only the column names that you can see in the schema "
"description. "
"Be careful to not query for columns that do not exist. "
"Pay attention to which column is in which table. "
"Also, qualify column names with the table name when needed. "
"You are required to use the following format, each taking one line:\n\n"
"Question: Question here\n"
"SQLQuery: SQL Query to run\n"
"SQLResult: Result of the SQLQuery\n"
"Answer: Final answer here\n\n"
"Only use tables listed below.\n"
"{schema}\n\n"
"Question: {query_str}\n"
"SQLQuery: "
),
query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
model: str = "gpt-4o-mini",
) -> str:
db_file, table_infos = await data_ingestion()
vector_index_dir = await index_all_tables(db_file)
table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
sql = await generate_sql(query, table_context, model, system_prompt)
return await generate_response(query, sql, db_file, model)
# {{/docs-fragment text_to_sql}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(text_to_sql)
print(run.url)
run.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*
Each row becomes a text node stored in LlamaIndex's `VectorStoreIndex`. This lets the system pull semantically similar rows when handling queries.
### Table retrieval and context building
We then retrieve the most relevant tables for a given query and build rich context that combines schema information with sample rows.
CODE4 ").strip()
# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
"""Generate SQL query from natural language question and table context."""
llm = OpenAI(model=model)
fmt_messages = (
PromptTemplate(
prompt,
prompt_type=PromptType.TEXT_TO_SQL,
)
.partial_format(dialect="sqlite")
.format_messages(query_str=query, schema=table_context)
)
chat_response = await llm.achat(fmt_messages)
return parse_response_to_sql(chat_response)
@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
"""Run SQL query on database and synthesize final response."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
response_synthesis_prompt = PromptTemplate(
"Given an input question, synthesize a response from the query results.\n"
"Query: {query_str}\n"
"SQL: {sql_query}\n"
"SQL Response: {context_str}\n"
"Response: "
)
llm = OpenAI(model=model)
fmt_messages = response_synthesis_prompt.format_messages(
sql_query=sql,
context_str=str(retrieved_rows),
query_str=query,
)
chat_response = await llm.achat(fmt_messages)
return chat_response.message.content
# {{/docs-fragment sql_and_response}}
# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
system_prompt: str = (
"Given an input question, first create a syntactically correct {dialect} "
"query to run, then look at the results of the query and return the answer. "
"You can order the results by a relevant column to return the most "
"interesting examples in the database.\n\n"
"Never query for all the columns from a specific table, only ask for a "
"few relevant columns given the question.\n\n"
"Pay attention to use only the column names that you can see in the schema "
"description. "
"Be careful to not query for columns that do not exist. "
"Pay attention to which column is in which table. "
"Also, qualify column names with the table name when needed. "
"You are required to use the following format, each taking one line:\n\n"
"Question: Question here\n"
"SQLQuery: SQL Query to run\n"
"SQLResult: Result of the SQLQuery\n"
"Answer: Final answer here\n\n"
"Only use tables listed below.\n"
"{schema}\n\n"
"Question: {query_str}\n"
"SQLQuery: "
),
query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
model: str = "gpt-4o-mini",
) -> str:
db_file, table_infos = await data_ingestion()
vector_index_dir = await index_all_tables(db_file)
table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
sql = await generate_sql(query, table_context, model, system_prompt)
return await generate_response(query, sql, db_file, model)
# {{/docs-fragment text_to_sql}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(text_to_sql)
print(run.url)
run.wait()
CODE5
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "sqlalchemy>=2.0.0",
# "pandas>=2.0.0",
# "requests>=2.25.0",
# "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///
import asyncio
from pathlib import Path
import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
PromptTemplate,
SQLDatabase,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env
# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
"""Index a single table into vector store."""
path = f"{table_index_dir}/{table_name}"
engine = create_engine(database_uri)
def _fetch_rows():
with engine.connect() as conn:
cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
return cursor.fetchall()
result = await asyncio.to_thread(_fetch_rows)
nodes = [TextNode(text=str(tuple(row))) for row in result]
index = VectorStoreIndex(nodes)
index.set_index_id("vector_index")
index.storage_context.persist(path)
return path
@env.task
async def index_all_tables(db_file: File) -> Dir:
"""Index all tables concurrently."""
table_index_dir = "table_indices"
Path(table_index_dir).mkdir(exist_ok=True)
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
tasks = [
index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
for t in sql_database.get_usable_table_names()
]
await asyncio.gather(*tasks)
remote_dir = await Dir.from_local(table_index_dir)
return remote_dir
# {{/docs-fragment index_tables}}
@flyte.trace
async def get_table_schema_context(
table_schema_obj: SQLTableSchema,
database_uri: str,
) -> str:
"""Retrieve schema + optional description context for a single table."""
engine = create_engine(database_uri)
sql_database = SQLDatabase(engine)
table_info = sql_database.get_single_table_info(table_schema_obj.table_name)
if table_schema_obj.context_str:
table_info += f" The table description is: {table_schema_obj.context_str}"
return table_info
@flyte.trace
async def get_table_row_context(
table_schema_obj: SQLTableSchema,
local_vector_index_dir: str,
query: str,
) -> str:
"""Retrieve row-level context examples using vector search."""
storage_context = StorageContext.from_defaults(
persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
)
vector_index = load_index_from_storage(storage_context, index_id="vector_index")
vector_retriever = vector_index.as_retriever(similarity_top_k=2)
relevant_nodes = vector_retriever.retrieve(query)
if not relevant_nodes:
return ""
row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
for node in relevant_nodes:
row_context += str(node.get_content()) + "\n"
return row_context
async def process_table(
table_schema_obj: SQLTableSchema,
database_uri: str,
local_vector_index_dir: str,
query: str,
) -> str:
"""Combine schema + row context for one table."""
table_info = await get_table_schema_context(table_schema_obj, database_uri)
row_context = await get_table_row_context(
table_schema_obj, local_vector_index_dir, query
)
full_context = table_info
if row_context:
full_context += "\n" + row_context
print(f"Table Info: {full_context}")
return full_context
async def get_table_context_and_rows_str(
query: str,
database_uri: str,
table_schema_objs: list[SQLTableSchema],
vector_index_dir: Dir,
):
"""Get combined schema + row context for all tables."""
local_vector_index_dir = await vector_index_dir.download()
# run per-table work concurrently
context_strs = await asyncio.gather(
*[
process_table(t, database_uri, local_vector_index_dir, query)
for t in table_schema_objs
]
)
return "\n\n".join(context_strs)
# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
query: str,
table_infos: list[TableInfo | None],
db_file: File,
vector_index_dir: Dir,
) -> str:
"""Retrieve relevant tables and return schema context string."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
for t in table_infos
if t is not None
]
obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)
retrieved_schemas = obj_retriever.retrieve(query)
return await get_table_context_and_rows_str(
query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
)
# {{/docs-fragment retrieve_tables}}
def parse_response_to_sql(chat_response: ChatResponse) -> str:
"""Extract SQL query from LLM response."""
response = chat_response.message.content
sql_query_start = response.find("SQLQuery:")
if sql_query_start != -1:
response = response[sql_query_start:]
if response.startswith("SQLQuery:"):
response = response[len("SQLQuery:") :]
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
return response.strip().strip(" CODE6
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*
Finally, the SQL generation prompt combines the schema context, example rows, and strict formatting rules. After executing the generated query, the system synthesizes a natural language answer.
At this point, we have an end-to-end Text-to-SQL pipeline: natural language questions go in, SQL queries run, and answers come back. To make this workflow production-ready, we leveraged several Flyte 2 capabilities. Caching ensures that repeated steps, like table ingestion or vector indexing, don't need to rerun unnecessarily, saving time and compute. Containerization provides consistent, reproducible execution across environments, making it easier to scale and deploy. Observability features let us track every step of the pipeline, monitor performance, and debug issues quickly.
The pipeline now works end-to-end, but to understand how it performs across many prompts, and to improve that performance over time, we can start experimenting with prompt tuning.
Two things help make this process meaningful:
- **A clean evaluation dataset** - so we can measure accuracy against trusted ground truth.
- **A systematic evaluation loop** - so we can see whether prompt changes or other adjustments actually help.
With these in place, the next step is to build a "golden" QA dataset that will guide iterative prompt optimization.
## Building the QA dataset
> [!NOTE]
> The WikiTableQuestions dataset already includes question-answer pairs, available in its [GitHub repository](https://github.com/ppasupat/WikiTableQuestions/tree/master/data). To use them for this workflow, you'll need to adapt the data into the required format, but the raw material is there for you to build on.
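As a starting point for that adaptation, a small stdlib-only sketch can reshape the published pairs into the `input`/`target` CSV layout this workflow consumes. The TSV column names `utterance` and `targetValue` are assumptions based on the WikiTableQuestions release and are worth verifying against the files you download; note that the repository provides answers, not gold SQL, so any `sql` column must still be generated or written by hand.

```python
import csv


def wtq_to_eval_csv(tsv_path: str, csv_path: str) -> None:
    """Reshape WikiTableQuestions question/answer pairs into the
    input/target CSV layout used by the evaluation dataset.

    Assumes the repository's tab-separated format, with `utterance`
    holding the question and `targetValue` holding the answer.
    """
    with open(tsv_path, newline="", encoding="utf-8") as src:
        rows = list(csv.DictReader(src, delimiter="\t"))
    with open(csv_path, "w", newline="", encoding="utf-8") as dst:
        writer = csv.DictWriter(dst, fieldnames=["input", "target"])
        writer.writeheader()
        for row in rows:
            writer.writerow({"input": row["utterance"], "target": row["targetValue"]})
```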
We generate a dataset of natural language questions paired with executable SQL queries. This dataset acts as the benchmark for prompt tuning and evaluation.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pandas>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "pydantic>=2.0.0",
# ]
# main = "build_eval_dataset"
# params = ""
# ///
import sqlite3
import flyte
import pandas as pd
from data_ingestion import data_ingestion
from flyte.io import File
from llama_index.core import PromptTemplate
from llama_index.llms.openai import OpenAI
from utils import env
from pydantic import BaseModel
class QAItem(BaseModel):
question: str
sql: str
class QAList(BaseModel):
items: list[QAItem]
# {{docs-fragment get_and_split_schema}}
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
"""
Download the SQLite DB, extract schema info (columns + sample rows),
then split it into chunks with up to `tables_per_chunk` tables each.
"""
await db_file.download(local_path="local_db.sqlite")
conn = sqlite3.connect("local_db.sqlite")
cursor = conn.cursor()
tables = cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table';"
).fetchall()
schema_blocks = []
for table in tables:
table_name = table[0]
# columns
cursor.execute(f"PRAGMA table_info({table_name});")
columns = [col[1] for col in cursor.fetchall()]
block = f"Table: {table_name}({', '.join(columns)})"
# sample rows
cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
rows = cursor.fetchall()
if rows:
block += "\nSample rows:\n"
for row in rows:
block += f"{row}\n"
schema_blocks.append(block)
conn.close()
chunks = []
current_chunk = []
for block in schema_blocks:
current_chunk.append(block)
if len(current_chunk) >= tables_per_chunk:
chunks.append("\n".join(current_chunk))
current_chunk = []
if current_chunk:
chunks.append("\n".join(current_chunk))
return chunks
# {{/docs-fragment get_and_split_schema}}
# {{docs-fragment generate_questions_and_sql}}
@flyte.trace
async def generate_questions_and_sql(
schema: str, num_samples: int, batch_size: int
) -> QAList:
llm = OpenAI(model="gpt-4.1")
prompt_tmpl = PromptTemplate(
"""Prompt: You are helping build a Text-to-SQL dataset.
Here is the database schema:
{schema}
Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.
Reasoning process (you must follow this internally):
- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.
Final Output:
Return only a JSON object with one field:
- "items": a list of {num} objects, each with:
- "question": the natural language question
- "sql": the corresponding SQL query
"""
)
all_items: list[QAItem] = []
# batch generation
for start in range(0, num_samples, batch_size):
current_num = min(batch_size, num_samples - start)
        response = llm.structured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",  # fill the {dialect} placeholder used in the prompt
        )
all_items.extend(response.items)
# deduplicate
seen = set()
unique_items: list[QAItem] = []
for item in all_items:
key = (item.question.strip().lower(), item.sql.strip().lower())
if key not in seen:
seen.add(key)
unique_items.append(item)
return QAList(items=unique_items[:num_samples])
# {{/docs-fragment generate_questions_and_sql}}
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
"""Validate a batch of question/sql/result dicts using one LLM call."""
batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""
for i, pair in enumerate(pairs, start=1):
batch_prompt += f"""
Example {i}:
Question:
{pair['question']}
SQL:
{pair['sql']}
Result:
{pair['rows']}
---
"""
llm = OpenAI(model="gpt-4.1")
resp = await llm.acomplete(batch_prompt)
# Expect exactly one True/False per example
results = [
line.strip()
for line in resp.text.splitlines()
if line.strip() in ("True", "False")
]
return results
# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
await db_file.download(local_path="local_db.sqlite")
conn = sqlite3.connect("local_db.sqlite")
cursor = conn.cursor()
qa_data = []
batch = []
for pair in question_sql_pairs.items:
q, sql = pair.question, pair.sql
try:
cursor.execute(sql)
rows = cursor.fetchall()
batch.append({"question": q, "sql": sql, "rows": str(rows)})
# process when batch is full
if len(batch) == batch_size:
results = await llm_validate_batch(batch)
for pair, is_valid in zip(batch, results):
if is_valid == "True":
qa_data.append(
{
"input": pair["question"],
"sql": pair["sql"],
"target": pair["rows"],
}
)
else:
print(f"Filtered out incorrect result for: {pair['question']}")
batch = []
except Exception as e:
print(f"Skipping invalid SQL: {sql} ({e})")
# process leftover batch
if batch:
results = await llm_validate_batch(batch)
for pair, is_valid in zip(batch, results):
if is_valid == "True":
qa_data.append(
{
"input": pair["question"],
"sql": pair["sql"],
"target": pair["rows"],
}
)
else:
print(f"Filtered out incorrect result for: {pair['question']}")
conn.close()
return qa_data
# {{/docs-fragment validate_sql}}
@flyte.trace
async def save_to_csv(qa_data: list[dict]) -> File:
df = pd.DataFrame(qa_data, columns=["input", "target", "sql"])
csv_file = "qa_dataset.csv"
df.to_csv(csv_file, index=False)
return await File.from_local(csv_file)
# {{docs-fragment build_eval_dataset}}
@env.task
async def build_eval_dataset(
num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> File:
db_file, _ = await data_ingestion()
schema_chunks = await get_and_split_schema(db_file, tables_per_chunk)
per_chunk_samples = max(1, num_samples // len(schema_chunks))
final_qa_data = []
for chunk in schema_chunks:
qa_list = await generate_questions_and_sql(
schema=chunk,
num_samples=per_chunk_samples,
batch_size=batch_size,
)
qa_data = await validate_sql(db_file, qa_list, batch_size)
final_qa_data.extend(qa_data)
csv_file = await save_to_csv(final_qa_data)
return csv_file
# {{/docs-fragment build_eval_dataset}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(build_eval_dataset)
print(run.url)
run.wait()
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pandas>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "pydantic>=2.0.0",
# ]
# main = "build_eval_dataset"
# params = ""
# ///
import sqlite3
import flyte
import pandas as pd
from data_ingestion import data_ingestion
from flyte.io import File
from llama_index.core import PromptTemplate
from llama_index.llms.openai import OpenAI
from utils import env
from pydantic import BaseModel
class QAItem(BaseModel):
question: str
sql: str
class QAList(BaseModel):
items: list[QAItem]
# {{docs-fragment get_and_split_schema}}
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
"""
Download the SQLite DB, extract schema info (columns + sample rows),
then split it into chunks with up to `tables_per_chunk` tables each.
"""
await db_file.download(local_path="local_db.sqlite")
conn = sqlite3.connect("local_db.sqlite")
cursor = conn.cursor()
tables = cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table';"
).fetchall()
schema_blocks = []
for table in tables:
table_name = table[0]
# columns
cursor.execute(f"PRAGMA table_info({table_name});")
columns = [col[1] for col in cursor.fetchall()]
block = f"Table: {table_name}({', '.join(columns)})"
# sample rows
cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
rows = cursor.fetchall()
if rows:
block += "\nSample rows:\n"
for row in rows:
block += f"{row}\n"
schema_blocks.append(block)
conn.close()
chunks = []
current_chunk = []
for block in schema_blocks:
current_chunk.append(block)
if len(current_chunk) >= tables_per_chunk:
chunks.append("\n".join(current_chunk))
current_chunk = []
if current_chunk:
chunks.append("\n".join(current_chunk))
return chunks
# {{/docs-fragment get_and_split_schema}}
# {{docs-fragment generate_questions_and_sql}}
@flyte.trace
async def generate_questions_and_sql(
schema: str, num_samples: int, batch_size: int
) -> QAList:
llm = OpenAI(model="gpt-4.1")
prompt_tmpl = PromptTemplate(
"""Prompt: You are helping build a Text-to-SQL dataset.
Here is the database schema:
{schema}
Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.
Reasoning process (you must follow this internally):
- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.
Final Output:
Return only a JSON object with one field:
- "items": a list of {num} objects, each with:
- "question": the natural language question
- "sql": the corresponding SQL query
"""
)
all_items: list[QAItem] = []
# batch generation
for start in range(0, num_samples, batch_size):
current_num = min(batch_size, num_samples - start)
response = llm.structured_predict(
QAList,
prompt_tmpl,
schema=schema,
num=current_num,
)
all_items.extend(response.items)
# deduplicate
seen = set()
unique_items: list[QAItem] = []
for item in all_items:
key = (item.question.strip().lower(), item.sql.strip().lower())
if key not in seen:
seen.add(key)
unique_items.append(item)
return QAList(items=unique_items[:num_samples])
# {{/docs-fragment generate_questions_and_sql}}
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
"""Validate a batch of question/sql/result dicts using one LLM call."""
batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""
for i, pair in enumerate(pairs, start=1):
batch_prompt += f"""
Example {i}:
Question:
{pair['question']}
SQL:
{pair['sql']}
Result:
{pair['rows']}
---
"""
llm = OpenAI(model="gpt-4.1")
resp = await llm.acomplete(batch_prompt)
# Expect exactly one True/False per example
results = [
line.strip()
for line in resp.text.splitlines()
if line.strip() in ("True", "False")
]
return results
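The response parsing in `llm_validate_batch` keeps only lines that are exactly "True" or "False", silently dropping any chatter the reviewer model adds. Pulled out as a standalone sketch (the `parse_verdicts` name is ours, not part of the script):

```python
def parse_verdicts(text: str) -> list[str]:
    """Extract one True/False verdict per line, ignoring any other output."""
    return [
        line.strip()
        for line in text.splitlines()
        if line.strip() in ("True", "False")
    ]
```

Note that if the model emits fewer verdicts than examples, the later `zip(batch, results)` will silently skip the unmatched tail, which is why the prompt insists on one answer per line.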
# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
await db_file.download(local_path="local_db.sqlite")
conn = sqlite3.connect("local_db.sqlite")
cursor = conn.cursor()
qa_data = []
batch = []
for pair in question_sql_pairs.items:
q, sql = pair.question, pair.sql
try:
cursor.execute(sql)
rows = cursor.fetchall()
batch.append({"question": q, "sql": sql, "rows": str(rows)})
# process when batch is full
if len(batch) == batch_size:
results = await llm_validate_batch(batch)
for pair, is_valid in zip(batch, results):
if is_valid == "True":
qa_data.append(
{
"input": pair["question"],
"sql": pair["sql"],
"target": pair["rows"],
}
)
else:
print(f"Filtered out incorrect result for: {pair['question']}")
batch = []
except Exception as e:
print(f"Skipping invalid SQL: {sql} ({e})")
# process leftover batch
if batch:
results = await llm_validate_batch(batch)
for pair, is_valid in zip(batch, results):
if is_valid == "True":
qa_data.append(
{
"input": pair["question"],
"sql": pair["sql"],
"target": pair["rows"],
}
)
else:
print(f"Filtered out incorrect result for: {pair['question']}")
conn.close()
return qa_data
# {{/docs-fragment validate_sql}}
@flyte.trace
async def save_to_csv(qa_data: list[dict]) -> File:
df = pd.DataFrame(qa_data, columns=["input", "target", "sql"])
csv_file = "qa_dataset.csv"
df.to_csv(csv_file, index=False)
return await File.from_local(csv_file)
# {{docs-fragment build_eval_dataset}}
@env.task
async def build_eval_dataset(
num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> File:
db_file, _ = await data_ingestion()
schema_chunks = await get_and_split_schema(db_file, tables_per_chunk)
per_chunk_samples = max(1, num_samples // len(schema_chunks))
final_qa_data = []
for chunk in schema_chunks:
qa_list = await generate_questions_and_sql(
schema=chunk,
num_samples=per_chunk_samples,
batch_size=batch_size,
)
qa_data = await validate_sql(db_file, qa_list, batch_size)
final_qa_data.extend(qa_data)
csv_file = await save_to_csv(final_qa_data)
return csv_file
# {{/docs-fragment build_eval_dataset}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(build_eval_dataset)
print(run.url)
run.wait()
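The deduplication pass at the end of `generate_questions_and_sql` is worth looking at in isolation: it treats two pairs as duplicates when they differ only in case or surrounding whitespace, and preserves first-seen order. A minimal sketch of the same logic (the `dedupe_qa` helper name is ours, not part of the script):

```python
def dedupe_qa(pairs: list[tuple[str, str]], limit: int) -> list[tuple[str, str]]:
    """Drop (question, sql) pairs that differ only in case/whitespace, keeping order."""
    seen: set[tuple[str, str]] = set()
    unique: list[tuple[str, str]] = []
    for question, sql in pairs:
        key = (question.strip().lower(), sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique.append((question, sql))
    return unique[:limit]  # cap at the requested sample count
```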
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pandas>=2.0.0",
# "sqlalchemy>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from data_ingestion import DatabaseConfig, TableInfo  # DatabaseConfig assumed to live in data_ingestion
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env
CSS = """
"""
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
    Load Q&A data from a CSV source (a Flyte `File`, local path, or URL) and
    split it into validation and test DataFrames.
    The CSV must have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("CSV must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Val/Test split
df_renamed = df.rename(columns={"input": "question", "target": "answer"})
n = len(df_renamed)
split = n // 2
df_val = df_renamed.iloc[:split]
df_test = df_renamed.iloc[split:]
return df_val, df_test
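The 50/50 split performed by `data_prep` can be sanity-checked in isolation. A small illustrative sketch with toy data (column names follow the script's conventions):

```python
import pandas as pd

# Toy frame with the expected 'input'/'target' columns
df = pd.DataFrame({
    "input": [f"q{i}" for i in range(10)],
    "target": [f"a{i}" for i in range(10)],
})
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)  # shuffle deterministically
df = df.rename(columns={"input": "question", "target": "answer"})
split = len(df) // 2
df_val, df_test = df.iloc[:split], df.iloc[split:]
```

With an odd row count the integer division gives the test set one extra row; with the fixed `random_state` the shuffle is reproducible across runs.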
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
if retrieved_rows:
# Get the structured result and stringify
return str(retrieved_rows[0].node.metadata["result"])
return ""
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
db_file: File,
table_infos: list[TableInfo | None],
vector_index_dir: Dir,
) -> dict:
# Generate response from target model
table_context = await retrieve_tables(
question, table_infos, db_file, vector_index_dir
)
sql = await generate_sql(
question,
table_context,
target_model_config.model_name,
target_model_config.prompt,
)
    sql = sql.replace("sql\n", "")  # drop a stray "sql" language tag left over from a fenced code block
try:
response = await generate_response(db_file, sql)
except Exception as e:
print(f"Failed to generate response for question {question}: {e}")
response = None
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
query_str=question,
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"sql": sql,
"is_correct": verdict_clean == "true",
}
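The verdict normalization at the end of `generate_and_review` maps anything other than a clean "true"/"false" to "not sure", and only an exact "true" counts as correct. Extracted as a helper (a sketch; the `normalize_verdict` name is ours):

```python
def normalize_verdict(verdict: str) -> tuple[str, bool]:
    """Return the cleaned verdict string and whether it counts as correct."""
    cleaned = verdict.strip().lower()
    if cleaned not in {"true", "false"}:
        cleaned = "not sure"  # hedge on anything unexpected from the reviewer LLM
    return cleaned, cleaned == "true"
```

Treating ambiguous reviewer output as incorrect (rather than raising) keeps a single bad LLM response from aborting a whole evaluation batch.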
async def run_grouped_task(
i,
index,
question,
answer,
sql,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
db_file,
table_infos,
vector_index_dir,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
db_file,
table_infos,
vector_index_dir,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "β Yes"
else:
correct_html = "β No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
# Update chart
await flyte.report.log.aio(
f"",
do_flush=True,
)
# Add row to table
await flyte.report.log.aio(
f"""
""",
do_flush=True,
)
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
await flyte.report.log.aio(
f"""
{html.escape(prompt)}
{pct:.1f}%
""",
do_flush=True,
)
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
ground_truth_csv: File | str = "/root/ground_truth.csv",
db_config: DatabaseConfig = DatabaseConfig(
csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob="WikiTableQuestions/csv/200-csv/*.csv",
concurrency=5,
model="gpt-4o-mini",
),
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.
artists(id, name)
albums(id, title, artist_id, release_year)
How many albums did The Beatles release?
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 5,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
ground_truth_csv = await File.from_local(ground_truth_csv)
df_val, df_test = await data_prep(ground_truth_csv)
best_prompt, val_accuracy = await prompt_optimizer(
df_val,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
db_config,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
return {
"best_prompt": best_prompt,
"validation_accuracy": val_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
CODE11
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pandas>=2.0.0",
# "sqlalchemy>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env
CSS = """
"""
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into val/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Val/Test split
df_renamed = df.rename(columns={"input": "question", "target": "answer"})
n = len(df_renamed)
split = n // 2
df_val = df_renamed.iloc[:split]
df_test = df_renamed.iloc[split:]
return df_val, df_test
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
if retrieved_rows:
# Get the structured result and stringify
return str(retrieved_rows[0].node.metadata["result"])
return ""
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
db_file: File,
table_infos: list[TableInfo | None],
vector_index_dir: Dir,
) -> dict:
# Generate response from target model
table_context = await retrieve_tables(
question, table_infos, db_file, vector_index_dir
)
sql = await generate_sql(
question,
table_context,
target_model_config.model_name,
target_model_config.prompt,
)
sql = sql.replace("sql\n", "")
try:
response = await generate_response(db_file, sql)
except Exception as e:
print(f"Failed to generate response for question {question}: {e}")
response = None
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
query_str=question,
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"sql": sql,
"is_correct": verdict_clean == "true",
}
async def run_grouped_task(
i,
index,
question,
answer,
sql,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
db_file,
table_infos,
vector_index_dir,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
db_file,
table_infos,
vector_index_dir,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "β Yes"
else:
correct_html = "β No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
# Update chart
await flyte.report.log.aio(
f"",
do_flush=True,
)
# Add row to table
await flyte.report.log.aio(
f"""
""",
do_flush=True,
)
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
await flyte.report.log.aio(
f"""
{html.escape(prompt)}
{pct:.1f}%
""",
do_flush=True,
)
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
ground_truth_csv: File | str = "/root/ground_truth.csv",
db_config: DatabaseConfig = DatabaseConfig(
csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob="WikiTableQuestions/csv/200-csv/*.csv",
concurrency=5,
model="gpt-4o-mini",
),
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.
artists(id, name)
albums(id, title, artist_id, release_year)
How many albums did The Beatles release?
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 5,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
ground_truth_csv = await File.from_local(ground_truth_csv)
df_val, df_test = await data_prep(ground_truth_csv)
best_prompt, val_accuracy = await prompt_optimizer(
df_val,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
db_config,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
return {
"best_prompt": best_prompt,
"validation_accuracy": val_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
```
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pandas>=2.0.0",
# "sqlalchemy>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env
CSS = """
"""
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into val/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Val/Test split
df_renamed = df.rename(columns={"input": "question", "target": "answer"})
n = len(df_renamed)
split = n // 2
df_val = df_renamed.iloc[:split]
df_test = df_renamed.iloc[split:]
return df_val, df_test
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
if retrieved_rows:
# Get the structured result and stringify
return str(retrieved_rows[0].node.metadata["result"])
return ""
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
db_file: File,
table_infos: list[TableInfo | None],
vector_index_dir: Dir,
) -> dict:
# Generate response from target model
table_context = await retrieve_tables(
question, table_infos, db_file, vector_index_dir
)
sql = await generate_sql(
question,
table_context,
target_model_config.model_name,
target_model_config.prompt,
)
sql = sql.replace("sql\n", "")
try:
response = await generate_response(db_file, sql)
except Exception as e:
print(f"Failed to generate response for question {question}: {e}")
response = None
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
query_str=question,
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"sql": sql,
"is_correct": verdict_clean == "true",
}
async def run_grouped_task(
i,
index,
question,
answer,
sql,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
db_file,
table_infos,
vector_index_dir,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
db_file,
table_infos,
vector_index_dir,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "✅ Yes"
else:
correct_html = "❌ No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
# Update chart
await flyte.report.log.aio(
f"",
do_flush=True,
)
# Add row to table
await flyte.report.log.aio(
f"""
""",
do_flush=True,
)
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
await flyte.report.log.aio(
f"""
{html.escape(prompt)}
{pct:.1f}%
""",
do_flush=True,
)
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
ground_truth_csv: File | str = "/root/ground_truth.csv",
db_config: DatabaseConfig = DatabaseConfig(
csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob="WikiTableQuestions/csv/200-csv/*.csv",
concurrency=5,
model="gpt-4o-mini",
),
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.
artists(id, name)
albums(id, title, artist_id, release_year)
How many albums did The Beatles release?
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 5,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
ground_truth_csv = await File.from_local(ground_truth_csv)
df_val, df_test = await data_prep(ground_truth_csv)
best_prompt, val_accuracy = await prompt_optimizer(
df_val,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
db_config,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
return {
"best_prompt": best_prompt,
"validation_accuracy": val_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
```
```
python create_qa_dataset.py
```
```
python optimizer.py
```
## What we observed
Prompt optimization didn't consistently lift SQL accuracy in this workflow. Accuracy plateaued near the baseline. But the process surfaced valuable lessons about what matters when building LLM-powered systems on real infrastructure.
- **Schema clarity matters**: CSV ingestion produced tables with overlapping names, creating ambiguity. This showed how schema design and metadata hygiene directly affect downstream evaluation.
- **Ground truth needs trust**: Because the dataset came from LLM outputs, noise remained even after filtering. Human review proved essential. Golden datasets need deliberate curation, not just automation.
- **Optimization needs context**: The optimizer couldn't "see" which examples failed, limiting its ability to improve. Feeding failures directly risks overfitting. A structured way to capture and reuse evaluation signals is the right long-term path.
Sometimes prompt tweaks alone can lift accuracy, but other times the real bottleneck lives in the data, the schema, or the evaluation loop. The lesson isn't "prompt optimization doesn't work", but that its impact depends on the system around it. Accuracy improves most reliably when prompts evolve alongside clean data, trusted evaluation, and observable feedback loops.
## The bigger lesson
Evaluation and optimization aren't one-off experiments; they're continuous processes. What makes them sustainable isn't a clever prompt but the platform around it.
Systems succeed when they:
- **Observe** failures with clarity: track exactly what failed and why.
- **Remain durable** across iterations: run pipelines that are stable, reproducible, and comparable over time.
That's where Flyte 2 comes in. Prompt optimization is one lever, but it becomes powerful only when combined with:
- Clean, human-validated evaluation datasets.
- Systematic reporting and feedback loops.
**The real takeaway: improving LLM pipelines isn't about chasing the perfect prompt. It's about designing workflows with observability and durability at the core, so that every experiment compounds into long-term progress.**
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/auto_prompt_engineering ===
# Automatic prompt engineering
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/auto_prompt_engineering).
When building with LLMs and agents, the first prompt almost never works. We usually need several iterations before results are useful. Doing this manually is slow, inconsistent, and hard to reproduce.
Flyte turns prompt engineering into a systematic process. With Flyte we can:
- Generate candidate prompts automatically.
- Run evaluations in parallel.
- Track results in real time with built-in observability.
- Recover from failures without losing progress.
- Trace the lineage of every experiment for reproducibility.
And we're not limited to prompts. Just like [hyperparameter optimization](../hpo/_index) in ML, we can tune model temperature, retrieval strategies, tool usage, and more. Over time, this grows into full agentic evaluations, tracking not only prompts but also how agents behave, make decisions, and interact with their environment.
In this tutorial, we'll build an automated prompt engineering pipeline with Flyte, step by step.
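The evaluation loop at the heart of this pipeline can be sketched without any Flyte machinery: score a candidate prompt against a small golden dataset, with a semaphore capping the number of concurrent model calls. This is a minimal, self-contained sketch; the `grade` function and its canned answers are hypothetical stand-ins for the real target-model and review-model round trip.

```python
import asyncio

# Toy golden dataset: (question, expected_answer) pairs.
DATASET = [("2+2", "4"), ("3*3", "9"), ("10-7", "3")]


async def grade(prompt: str, question: str, expected: str) -> bool:
    # Hypothetical stand-in for the target + review models.
    await asyncio.sleep(0)  # simulate an async API call
    fake_model_answers = {"2+2": "4", "3*3": "9", "10-7": "2"}
    return fake_model_answers[question] == expected


async def evaluate_prompt_sketch(prompt: str, concurrency: int = 2) -> float:
    # Cap the number of in-flight calls, as the real pipeline does.
    semaphore = asyncio.Semaphore(concurrency)

    async def bounded(question: str, expected: str) -> bool:
        async with semaphore:
            return await grade(prompt, question, expected)

    results = await asyncio.gather(*(bounded(q, a) for q, a in DATASET))
    return sum(results) / len(results)


accuracy = asyncio.run(evaluate_prompt_sketch("Think step by step."))
```

The real pipeline follows the same shape, but each graded row also becomes a traced, observable unit of work in the Flyte UI.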
## Set up the environment
First, let's configure our task environment.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pandas==2.3.1",
# "pyarrow==21.0.0",
# "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File
env = flyte.TaskEnvironment(
name="auto-prompt-engineering",
image=flyte.Image.from_uv_script(
__file__, name="auto-prompt-engineering", pre=True
),
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
resources=flyte.Resources(cpu=1),
)
CSS = """
"""
# {{/docs-fragment env}}
```
```
flyte create secret openai_api_key
```
We also define CSS styles for live HTML reports that track prompt optimization in real time:

## Prepare the evaluation dataset
Next, we define our golden dataset, a set of prompts with known outputs. This dataset is used to evaluate the quality of generated prompts.
For this tutorial, we use a small geometric shapes dataset. To keep it portable, the data prep task takes a CSV file (as a Flyte `File` or a string for files available remotely) and splits it into train and test subsets.
If you already have prompts and outputs in Google Sheets, simply export them as CSV with two columns: `input` and `target`.
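In miniature, and assuming a toy frame in place of the real sheet, the deterministic shuffle-and-slice recipe used by `data_prep` looks like this:

```python
import pandas as pd

# Toy stand-in for the exported sheet (the real one has 'input'/'target' columns).
df = pd.DataFrame({
    "input": [f"question {i}" for i in range(10)],
    "target": [f"answer {i}" for i in range(10)],
})

# Shuffle deterministically, then take positional slices for train/test.
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
df_train = df.iloc[:6].rename(columns={"input": "question", "target": "answer"})
df_test = df.iloc[6:].rename(columns={"input": "question", "target": "answer"})
```

A fixed `random_state` keeps the split reproducible across runs, so accuracies from different optimization iterations stay comparable.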
```
# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )
    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")
    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})
    return df_train, df_test
# {{/docs-fragment data_prep}}
```
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pandas==2.3.1",
# "pyarrow==21.0.0",
# "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File
env = flyte.TaskEnvironment(
name="auto-prompt-engineering",
image=flyte.Image.from_uv_script(
__file__, name="auto-prompt-engineering", pre=True
),
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
resources=flyte.Resources(cpu=1),
)
CSS = """
"""
# {{/docs-fragment env}}
# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Train/Test split
df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})
return df_train, df_test
# {{/docs-fragment data_prep}}
# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
# {{/docs-fragment model_config}}
# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
# {{/docs-fragment call_model}}
# {{docs-fragment generate_and_review}}
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
) -> dict:
# Generate response from target model
response = await call_model(
target_model_config,
[
{"role": "system", "content": target_model_config.prompt},
{"role": "user", "content": question},
],
)
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"is_correct": verdict_clean == "true",
}
# {{/docs-fragment generate_and_review}}
async def run_grouped_task(
i,
index,
question,
answer,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "β Yes"
else:
correct_html = "β No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
# Update chart
await flyte.report.log.aio(
f"",
do_flush=True,
)
# Add row to table
await flyte.report.log.aio(
f"""
""",
do_flush=True,
)
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
await flyte.report.log.aio(
f"""
{html.escape(prompt)}
{pct:.1f}%
""",
do_flush=True,
)
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
csv_file: File | str = "https://dub.sh/geometric-shapes",
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="Solve the given problem about geometric shapes. Think step by step.",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""You are a review model tasked with evaluating the correctness of a response to a navigation problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth number.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.
Model Response:
{response}
Ground Truth:
{answer}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicate better quality.
{prompt_scores_str}
Each prompt was used together with a problem statement around geometric shapes.
This SVG path element draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
(B)
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't worked in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 3,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(csv_file, str) and os.path.isfile(csv_file):
csv_file = await File.from_local(csv_file)
df_train, df_test = await data_prep(csv_file)
best_prompt, training_accuracy = await prompt_optimizer(
df_train,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
)
return {
"best_prompt": best_prompt,
"training_accuracy": training_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
# {{docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
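The optimizer model is instructed to wrap its new prompt in double square brackets, which is why `re` appears in the imports. The extraction code itself isn't shown in this excerpt, but a minimal sketch of such a parser might look like the following (the `extract_prompt` helper name is ours, not from the source):

```python
import re


def extract_prompt(optimizer_output: str) -> str:
    """Pull the prompt text out of the [[...]] markers in the optimizer's reply."""
    # DOTALL lets the prompt span multiple lines; the non-greedy group stops
    # at the first closing marker.
    match = re.search(r"\[\[(.*?)\]\]", optimizer_output, re.DOTALL)
    if match is None:
        raise ValueError("No [[...]] prompt found in optimizer output.")
    return match.group(1).strip()


print(extract_prompt("Analysis of past prompts... [[Think carefully, then answer.]]"))
# → Think carefully, then answer.
```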
Then we define a Flyte `trace` to call the model. Unlike a task, a trace runs within the same runtime as the parent process. Since the model is hosted externally, this keeps the call lightweight but still observable.
```
# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]
# {{/docs-fragment call_model}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
## Evaluate prompts
We now define the evaluation process.
Each question in the dataset is evaluated in parallel, with a semaphore capping concurrency. A helper function ties the `generate_and_review` task together with an HTML report template, and `asyncio.gather` runs the evaluations concurrently.
The function measures accuracy as the fraction of responses that match the ground truth. Flyte streams these results to the UI, so you can watch evaluations happen live.
```
# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate a response from the target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )
    # Format the review prompt with the response and ground-truth answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)
    # Normalize the verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"
    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }
# {{/docs-fragment generate_and_review}}


async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )
    async with counter_lock:
        # Update counters
        counter["processed"] += 1
        if result["is_correct"]:
            counter["correct"] += 1
            correct_html = "✅ Yes"
        else:
            correct_html = "❌ No"
        # Calculate running accuracy
        accuracy_pct = (counter["correct"] / counter["processed"]) * 100
        # Update chart
        await flyte.report.log.aio(
            f"",
            do_flush=True,
        )
        # Add row to table
        await flyte.report.log.aio(
            f"""
            """,
            do_flush=True,
        )
```

Run the workflow:

```
uv run optimizer.py
```
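The optimizer model is instructed to wrap its candidate prompt in double square brackets so that it can be separated from the model's think-out-loud reasoning. A minimal extraction helper might look like the sketch below (the function name and sample text are illustrative, not part of the script above):

```python
import re

def extract_prompt(optimizer_output: str) -> str:
    # The optimizer wraps its new prompt in [[...]]; take the last such span,
    # since everything before it is the model reasoning out loud.
    matches = re.findall(r"\[\[(.*?)\]\]", optimizer_output, flags=re.DOTALL)
    if not matches:
        raise ValueError("No [[prompt]] found in optimizer output")
    return matches[-1].strip()

sample = (
    "Step-by-step prompts scored highest, so I will keep that pattern. "
    "[[Carefully count the edges in the SVG path, then pick the matching shape.]]"
)
print(extract_prompt(sample))
```

Taking the last match rather than the first makes the parse robust to bracketed asides in the reasoning portion of the output.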
## Why this matters
Most prompt engineering pipelines start as quick scripts or notebooks. They're fine for experimenting, but they're difficult to scale, reproduce, or debug when things go wrong.
With Flyte 2, we get a more reliable setup:
- Run many evaluations in parallel using Flyte's asynchronous execution model and the `flyte.map` function.
- Watch accuracy improve in real time and link results back to the exact dataset, prompt, and model config used.
- Resume cleanly after failures without rerunning everything from scratch.
- Reuse the same pattern to tune other parameters like temperature, retrieval depth, or agent strategies, not just prompts.
## Next steps
You now have a working automated prompt engineering pipeline. Here's how you can take it further:
- **Optimize beyond prompts**: Tune temperature, retrieval strategies, or tool usage just like prompts.
- **Expand evaluation metrics**: Add latency, cost, robustness, or diversity alongside accuracy.
- **Move toward agentic evaluation**: Instead of single prompts, test how agents plan, use tools, and recover from failures in long-horizon tasks.
With this foundation, prompt engineering becomes repeatable, observable, and scalable, ready for production-grade LLM and agent systems.
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/deep-research ===
# Deep research
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/deep_research_agent); based on work by [Together AI](https://github.com/togethercomputer/open_deep_research).
This example demonstrates how to build an agentic workflow for deep research: a multi-step reasoning system that mirrors how a human researcher explores, analyzes, and synthesizes information from the web.
Deep research refers to the iterative process of thoroughly investigating a topic: identifying relevant sources, evaluating their usefulness, refining the research direction, and ultimately producing a well-structured summary or report. It's a long-running task that requires the agent to reason over time, adapt its strategy, and chain multiple steps together, making it an ideal fit for an agentic architecture.
In this example, we use:
- [Tavily](https://www.tavily.com/) to search for and retrieve high-quality online resources.
- [LiteLLM](https://litellm.ai/) to route LLM calls that perform reasoning, evaluation, and synthesis.
The agent executes a multi-step trajectory:
- Parallel search across multiple queries.
- Evaluation of retrieved results.
- Adaptive iteration: If results are insufficient, it formulates new research queries and repeats the search-evaluate cycle.
- Synthesis: After a fixed number of iterations, it produces a comprehensive research report.
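Stripped of Flyte specifics, this trajectory is a plan/search/evaluate/synthesize loop. The sketch below uses stand-in stubs to show the control flow; all function names and return values here are illustrative, not the actual task signatures from the repository:

```python
import asyncio

# Illustrative stand-ins for the real Flyte tasks; the actual workflow calls
# Tavily and an LLM at each of these steps.
async def plan_queries(topic):
    return [f"{topic} overview", f"{topic} recent work"]

async def search(query):
    return [f"result for: {query}"]

async def evaluate(topic, findings):
    # Judge coverage; if insufficient, propose follow-up queries.
    if len(findings) >= 4:
        return "sufficient", []
    return "insufficient", [f"{topic} details"]

async def synthesize(topic, findings):
    return f"Report on {topic} ({len(findings)} sources)"

async def research(topic, max_iterations=3):
    queries = await plan_queries(topic)
    findings = []
    for _ in range(max_iterations):
        # Parallel search across all current queries
        batches = await asyncio.gather(*(search(q) for q in queries))
        findings.extend(r for batch in batches for r in batch)
        verdict, queries = await evaluate(topic, findings)
        if verdict == "sufficient":
            break
    return await synthesize(topic, findings)

report = asyncio.run(research("flow batteries"))
print(report)
```

In the real workflow each of these steps is a Flyte task, which is what gives the loop per-step traceability and parallelism.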
What makes this workflow compelling is its dynamic, evolving nature. The agent isn't just following a fixed plan; it's making decisions in context, using multiple prompts and reasoning steps to steer the process.
Flyte is uniquely well-suited for this kind of system. It provides:
- Structured composition of dynamic reasoning steps
- Built-in parallelism for faster search and evaluation
- Traceability and observability into each step and iteration
- Scalability for long-running or compute-intensive workloads
Throughout this guide, we'll show how to design this workflow using the Flyte SDK, and how to unlock the full potential of agentic development with tools you already know and trust.
## Setting up the environment
Let's begin by setting up the task environment. We define the following components:
- Secrets for Together and Tavily API keys
- A custom image with required Python packages and apt dependencies (`pandoc`, `texlive-xetex`)
- External YAML file with all LLM prompts baked into the container
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "pydantic==2.11.5",
#     "litellm==1.72.2",
#     "tavily-python==0.7.5",
#     "together==1.5.24",
#     "markdown==3.8.2",
#     "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}
# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")
    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"{raw_content}\n\n{query}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""
    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)
    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue
            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    # Skip failed summarizations inside the zip so that a failure does not
    # shift the alignment between results and their summaries
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = []
    results_list = []

    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    if tasks:
        res_list = await asyncio.gather(*tasks)
        results_list.extend(res_list)

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """
    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{queries}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""
    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)

    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)
    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]
    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Filter the results based on the source list
    filtered_results = [
        results.results[i - 1] for i in sources if i - 1 < len(results.results)
    ]
    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer and "</think>" in answer:
        start = answer.find("<think>")
        end = answer.find("</think>") + len("</think>")
        answer = answer[:start] + answer[end:]
    return answer
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)

    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )
            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )
            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")

    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )
    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()
    return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
The Python packages are declared at the top of the file using the `uv` script style:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///
```
## Generate research queries
This task converts a user prompt into a list of focused queries. It makes two LLM calls: one to generate a high-level research plan, and a second to parse that plan into atomic search queries.
```python
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)

    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
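The plan-parsing call is constrained to `ResearchPlan`'s JSON schema, so extracting the queries afterward is a plain `json.loads` plus a key lookup. For illustration only, with a made-up JSON-mode response:

```python
import json

# A made-up response matching the {"queries": [...]} shape of ResearchPlan.
response_json = '{"queries": ["solid-state battery chemistry", "electrolyte costs"]}'

plan = json.loads(response_json)
print(plan["queries"])  # -> ['solid-state battery chemistry', 'electrolyte costs']
```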
LLM calls use LiteLLM, and each is wrapped with `flyte.trace` for observability; see the `asingle_shot_llm_call` helper in the source file:
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/libs/utils/llms.py*
> [!NOTE]
> We use `flyte.trace` to track intermediate steps within a task, like LLM calls or specific function executions. This lightweight decorator adds observability with minimal overhead and is especially useful for inspecting reasoning chains during task execution.
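The tasks consume `asingle_shot_llm_call` as an async generator, accumulating streamed chunks into one string. Here is a self-contained sketch of that consumption pattern, with a stub generator standing in for the real LiteLLM-backed helper:

```python
import asyncio

# Stub async generator standing in for `asingle_shot_llm_call`
# (the real helper streams chunks from LiteLLM).
async def fake_llm_call(message: str):
    for chunk in ["Plan: ", "research ", message]:
        yield chunk

async def run_call() -> str:
    plan = ""
    async for chunk in fake_llm_call("solar panel recycling"):
        plan += chunk  # accumulate the stream, as the tasks above do
    return plan

print(asyncio.run(run_call()))  # -> Plan: research solar panel recycling
```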
## Search and summarize
We submit each research query to Tavily and summarize the results using an LLM. We run all summarization tasks with `asyncio.gather`, which signals to Flyte that these tasks can be distributed across separate compute resources.
CODE6 "):
# Find the first line break after the opening backticks
first_linebreak = answer.find("\n", answer.find(" CODE7 "):
answer = answer.rstrip()[:-3].rstrip()
return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
topic: str,
budget: int = 3,
remove_thinking_tags: bool = True,
max_queries: int = 5,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 40,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
prompts_file: File | str = "prompts.yaml",
) -> str:
"""Main method to conduct research on a topic. Will be used for weave evals."""
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
# Step 1: Generate initial queries
queries = await generate_research_queries(
topic=topic,
planning_model=planning_model,
json_model=json_model,
prompts_file=prompts_file,
)
queries = [topic, *queries[: max_queries - 1]]
all_queries = queries.copy()
logging.info(f"Initial queries: {queries}")
if len(queries) == 0:
logging.error("No initial queries generated")
return "No initial queries generated"
# Step 2: Perform initial search
results = await search_all_queries(queries, summarization_model, prompts_file)
logging.info(f"Initial search complete, found {len(results.results)} results")
# Step 3: Conduct iterative research within budget
for iteration in range(budget):
with flyte.group(f"eval_iteration_{iteration}"):
# Evaluate if more research is needed
additional_queries = await evaluate_research_completeness(
topic=topic,
results=results,
queries=all_queries,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
)
# Filter out empty strings and check if any queries remain
additional_queries = [q for q in additional_queries if q]
if not additional_queries:
logging.info("No need for additional research")
break
# for debugging purposes we limit the number of queries
additional_queries = additional_queries[:max_queries]
logging.info(f"Additional queries: {additional_queries}")
# Expand research with new queries
new_results = await search_all_queries(
additional_queries, summarization_model, prompts_file
)
logging.info(
f"Follow-up search complete, found {len(new_results.results)} results"
)
results = results + new_results
all_queries.extend(additional_queries)
# Step 4: Generate final answer
logging.info(f"Generating final answer for topic: {topic}")
results = results.dedup()
logging.info(f"Deduplication complete, kept {len(results.results)} results")
filtered_results = await filter_results(
topic=topic,
results=results,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
)
logging.info(
f"LLM Filtering complete, kept {len(filtered_results.results)} results"
)
# Generate final answer
answer = await generate_research_answer(
topic=topic,
results=filtered_results,
remove_thinking_tags=remove_thinking_tags,
prompts_file=prompts_file,
answer_model=answer_model,
)
return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
topic: str = (
"List the essential requirements for a developer-focused agent orchestration system."
),
prompts_file: File | str = "/root/prompts.yaml",
budget: int = 2,
remove_thinking_tags: bool = True,
max_queries: int = 3,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 10,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
answer = await research_topic(
topic=topic,
budget=budget,
remove_thinking_tags=remove_thinking_tags,
max_queries=max_queries,
answer_model=answer_model,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
summarization_model=summarization_model,
prompts_file=prompts_file,
)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
toc_image_url = await generate_toc_image(
yaml.safe_load(yaml_contents)["data_visualization_prompt"],
planning_model,
topic,
)
html_content = await generate_html(answer, toc_image_url)
await flyte.report.replace.aio(html_content, do_flush=True)
await flyte.report.flush.aio()
return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
CODE8
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path
import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
DeepResearchResult,
DeepResearchResults,
ResearchPlan,
SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results
TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096
logging = AgentLogger("together.open_deep_research")
env = flyte.TaskEnvironment(
name="deep-researcher",
secrets=[
flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
],
image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
.with_apt_packages("pandoc", "texlive-xetex")
.with_source_file(Path("prompts.yaml"), "/root"),
resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}
# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
topic: str,
planning_model: str,
json_model: str,
prompts_file: File,
) -> list[str]:
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
PLANNING_PROMPT = prompts["planning_prompt"]
plan = ""
logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=PLANNING_PROMPT,
message=f"Research Topic: {topic}",
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
plan += chunk
print(chunk, end="", flush=True)
SEARCH_PROMPT = prompts["plan_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=SEARCH_PROMPT,
message=f"Plan to be parsed: {plan}",
response_format={
"type": "json_object",
"schema": ResearchPlan.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
plan = json.loads(response_json)
return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
raw_content: str,
query: str,
prompt: str,
summarization_model: str,
) -> str:
"""Summarize content asynchronously using the LLM"""
logging.info("Summarizing content asynchronously using the LLM")
result = ""
async for chunk in asingle_shot_llm_call(
model=summarization_model,
system_prompt=prompt,
message=f"{raw_content}\n\n{query}",
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
result += chunk
return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
query: str,
prompts_file: File,
summarization_model: str,
) -> DeepResearchResults:
"""Perform search for a single query"""
if len(query) > 400:
# NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
query = query[:400]
logging.info(f"Truncated query to 400 characters: {query}")
response = await atavily_search_results(query)
logging.info("Tavily Search Called.")
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
with flyte.group("summarize-content"):
# Create tasks for summarization
summarization_tasks = []
result_info = []
for result in response.results:
if result.raw_content is None:
continue
task = _summarize_content_async(
result.raw_content,
query,
RAW_CONTENT_SUMMARIZER_PROMPT,
summarization_model,
)
summarization_tasks.append(task)
result_info.append(result)
# Use return_exceptions=True to prevent exceptions from propagating
summarized_contents = await asyncio.gather(
*summarization_tasks, return_exceptions=True
)
# Filter out exceptions
summarized_contents = [
result for result in summarized_contents if not isinstance(result, Exception)
]
formatted_results = []
for result, summarized_content in zip(result_info, summarized_contents):
formatted_results.append(
DeepResearchResult(
title=result.title,
link=result.link,
content=result.content,
raw_content=result.raw_content,
filtered_raw_content=summarized_content,
)
)
return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
"""Execute searches for all queries in parallel"""
tasks = []
results_list = []
tasks = [
search_and_summarize(query, prompts_file, summarization_model)
for query in queries
]
if tasks:
res_list = await asyncio.gather(*tasks)
results_list.extend(res_list)
# Combine all results
combined_results = DeepResearchResults(results=[])
for results in results_list:
combined_results = combined_results + results
return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
topic: str,
results: DeepResearchResults,
queries: list[str],
prompts_file: File,
planning_model: str,
json_model: str,
) -> list[str]:
"""
Evaluate if the current search results are sufficient or if more research is needed.
Returns an empty list if research is complete, or a list of additional queries if more research is needed.
"""
# Format the search results for the LLM
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
EVALUATION_PROMPT = prompts["evaluation_prompt"]
logging.info("\nEvaluation: ")
evaluation = ""
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=EVALUATION_PROMPT,
message=(
f"{topic}\n\n"
f"{queries}\n\n"
f"{formatted_results}"
),
response_format=None,
max_completion_tokens=None,
):
evaluation += chunk
print(chunk, end="", flush=True)
EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=EVALUATION_PARSING_PROMPT,
message=f"Evaluation to be parsed: {evaluation}",
response_format={
"type": "json_object",
"schema": ResearchPlan.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
evaluation = json.loads(response_json)
return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
topic: str,
results: DeepResearchResults,
prompts_file: File,
planning_model: str,
json_model: str,
max_sources: int,
) -> DeepResearchResults:
"""Filter the search results based on the research plan"""
# Format the search results for the LLM, without the raw content
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
FILTER_PROMPT = prompts["filter_prompt"]
logging.info("\nFilter response: ")
filter_response = ""
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=FILTER_PROMPT,
message=(
f"{topic}\n\n"
f"{formatted_results}"
),
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
filter_response += chunk
print(chunk, end="", flush=True)
logging.info(f"Filter response: {filter_response}")
FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=FILTER_PARSING_PROMPT,
message=f"Filter response to be parsed: {filter_response}",
response_format={
"type": "json_object",
"schema": SourceList.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
sources = json.loads(response_json)["sources"]
logging.info(f"Filtered sources: {sources}")
if max_sources != -1:
sources = sources[:max_sources]
# Filter the results based on the source list
filtered_results = [
results.results[i - 1] for i in sources if i - 1 < len(results.results)
]
return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
def _remove_thinking_tags(answer: str) -> str:
"""Remove content within tags"""
while "" in answer and "" in answer:
start = answer.find("")
end = answer.find("") + len("")
answer = answer[:start] + answer[end:]
return answer
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
topic: str,
results: DeepResearchResults,
remove_thinking_tags: bool,
prompts_file: File,
answer_model: str,
) -> str:
"""
Generate a comprehensive answer to the research topic based on the search results.
Returns a detailed response that synthesizes information from all search results.
"""
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
ANSWER_PROMPT = prompts["answer_prompt"]
answer = ""
async for chunk in asingle_shot_llm_call(
model=answer_model,
system_prompt=ANSWER_PROMPT,
message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
response_format=None,
# NOTE: This is the max_token parameter for the LLM call on Together AI,
# may need to be changed for other providers
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
answer += chunk
# this is just to avoid typing complaints
if answer is None or not isinstance(answer, str):
logging.error("No answer generated")
return "No answer generated"
if remove_thinking_tags:
# Remove content within tags
answer = _remove_thinking_tags(answer)
# Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]
        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()
    return answer.strip()
# {{/docs-fragment generate_research_answer}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
## Filter results
In this step, we evaluate the relevance of search results and rank them. This task returns the most useful sources for the final synthesis.
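The filtering task asks the LLM for a ranked list of 1-based source indices and then maps those indices back to result objects. That mapping can be exercised in isolation; in this sketch, plain strings stand in for `DeepResearchResult` objects:

```python
def select_sources(results: list[str], sources: list[int], max_sources: int = -1) -> list[str]:
    """Keep LLM-ranked sources: truncate the ranking, convert 1-based indices,
    and silently drop any index that falls outside the result list."""
    if max_sources != -1:
        sources = sources[:max_sources]
    return [results[i - 1] for i in sources if 0 < i <= len(results)]

docs = ["intro", "benchmarks", "pricing", "tutorial"]
print(select_sources(docs, [4, 1, 3], max_sources=2))  # ['tutorial', 'intro']
```

Guarding against out-of-range indices matters here because the ranking comes from a parsed LLM response, which may cite source numbers that do not exist.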
```python
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
topic: str,
results: DeepResearchResults,
prompts_file: File,
planning_model: str,
json_model: str,
max_sources: int,
) -> DeepResearchResults:
"""Filter the search results based on the research plan"""
# Format the search results for the LLM, without the raw content
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
FILTER_PROMPT = prompts["filter_prompt"]
logging.info("\nFilter response: ")
filter_response = ""
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=FILTER_PROMPT,
message=(
f"{topic}\n\n"
f"{formatted_results}"
),
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
filter_response += chunk
print(chunk, end="", flush=True)
logging.info(f"Filter response: {filter_response}")
FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=FILTER_PARSING_PROMPT,
message=f"Filter response to be parsed: {filter_response}",
response_format={
"type": "json_object",
"schema": SourceList.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
sources = json.loads(response_json)["sources"]
logging.info(f"Filtered sources: {sources}")
if max_sources != -1:
sources = sources[:max_sources]
# Filter the results based on the source list
filtered_results = [
results.results[i - 1] for i in sources if i - 1 < len(results.results)
]
return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
## Orchestration
Next, we define a `research_topic` task to orchestrate the entire deep research workflow. It runs the core stages in sequence: generating research queries, performing search and summarization, evaluating the completeness of results, and producing the final report.
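Stripped of the LLM and search plumbing, `research_topic` is a budgeted refine loop: search, ask an evaluator for follow-up queries, and stop early once none remain. A minimal sketch with stubbed-in callables (the stubs are illustrative, not part of the agent):

```python
def run_research(topic, budget, evaluate, search):
    """Budgeted research loop: expand results until the evaluator is satisfied."""
    results = search([topic])
    for _ in range(budget):
        follow_ups = [q for q in evaluate(topic, results) if q]  # drop empty strings
        if not follow_ups:
            break  # research judged complete; stop spending budget
        results += search(follow_ups)
    return results

# The stub evaluator asks for two follow-ups, then reports completeness.
calls = iter([["q1", "q2"], []])
print(run_research("t", budget=3, evaluate=lambda t, r: next(calls), search=list))
# ['t', 'q1', 'q2']
```

The early `break` is what makes `budget` an upper bound rather than a fixed iteration count.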
```python
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
topic: str,
budget: int = 3,
remove_thinking_tags: bool = True,
max_queries: int = 5,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 40,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
prompts_file: File | str = "prompts.yaml",
) -> str:
"""Main method to conduct research on a topic. Will be used for weave evals."""
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
# Step 1: Generate initial queries
queries = await generate_research_queries(
topic=topic,
planning_model=planning_model,
json_model=json_model,
prompts_file=prompts_file,
)
queries = [topic, *queries[: max_queries - 1]]
all_queries = queries.copy()
logging.info(f"Initial queries: {queries}")
if len(queries) == 0:
logging.error("No initial queries generated")
return "No initial queries generated"
# Step 2: Perform initial search
results = await search_all_queries(queries, summarization_model, prompts_file)
logging.info(f"Initial search complete, found {len(results.results)} results")
# Step 3: Conduct iterative research within budget
for iteration in range(budget):
with flyte.group(f"eval_iteration_{iteration}"):
# Evaluate if more research is needed
additional_queries = await evaluate_research_completeness(
topic=topic,
results=results,
queries=all_queries,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
)
# Filter out empty strings and check if any queries remain
additional_queries = [q for q in additional_queries if q]
if not additional_queries:
logging.info("No need for additional research")
break
# for debugging purposes we limit the number of queries
additional_queries = additional_queries[:max_queries]
logging.info(f"Additional queries: {additional_queries}")
# Expand research with new queries
new_results = await search_all_queries(
additional_queries, summarization_model, prompts_file
)
logging.info(
f"Follow-up search complete, found {len(new_results.results)} results"
)
results = results + new_results
all_queries.extend(additional_queries)
# Step 4: Generate final answer
logging.info(f"Generating final answer for topic: {topic}")
results = results.dedup()
logging.info(f"Deduplication complete, kept {len(results.results)} results")
filtered_results = await filter_results(
topic=topic,
results=results,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
)
logging.info(
f"LLM Filtering complete, kept {len(filtered_results.results)} results"
)
# Generate final answer
answer = await generate_research_answer(
topic=topic,
results=filtered_results,
remove_thinking_tags=remove_thinking_tags,
prompts_file=prompts_file,
answer_model=answer_model,
)
return answer
# {{/docs-fragment research_topic}}
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
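`research_topic` above is decorated with `retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2)`. Assuming `backoff` is the initial delay in seconds and `backoff_factor` multiplies it on each subsequent attempt (check the Flyte API reference for the exact semantics), the wait schedule can be sketched as:

```python
def backoff_schedule(count: int, backoff: float, backoff_factor: float) -> list[float]:
    """Hypothetical reconstruction of the delay before each retry attempt."""
    return [backoff * backoff_factor**attempt for attempt in range(count)]

print(backoff_schedule(3, 10, 2))  # [10, 20, 40]
```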
The `main` task wraps this entire pipeline and adds report generation in HTML format as the final step.
It also serves as the main entry point to the workflow, allowing us to pass in all configuration parameters, including which LLMs to use at each stage.
This flexibility lets us mix and match models for planning, summarization, and final synthesis, helping us optimize for both cost and quality.
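One way to picture that configuration surface: each stage of the pipeline is keyed to its own model, and overriding one stage leaves the others at their defaults. The helper below is purely illustrative; only the stage names and default model strings come from the task signature:

```python
STAGE_DEFAULTS = {
    "planning_model": "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    "json_model": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    "summarization_model": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    "answer_model": "together_ai/deepseek-ai/DeepSeek-V3",
}

def resolve_models(**overrides: str) -> dict[str, str]:
    """Merge per-stage overrides over the defaults, rejecting unknown stage names."""
    unknown = set(overrides) - set(STAGE_DEFAULTS)
    if unknown:
        raise ValueError(f"unknown stages: {sorted(unknown)}")
    return {**STAGE_DEFAULTS, **overrides}

# e.g., use the stronger synthesis model for planning as well
models = resolve_models(planning_model="together_ai/deepseek-ai/DeepSeek-V3")
```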
```python
# {{docs-fragment main}}
@env.task(report=True)
async def main(
topic: str = (
"List the essential requirements for a developer-focused agent orchestration system."
),
prompts_file: File | str = "/root/prompts.yaml",
budget: int = 2,
remove_thinking_tags: bool = True,
max_queries: int = 3,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 10,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
answer = await research_topic(
topic=topic,
budget=budget,
remove_thinking_tags=remove_thinking_tags,
max_queries=max_queries,
answer_model=answer_model,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
summarization_model=summarization_model,
prompts_file=prompts_file,
)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
toc_image_url = await generate_toc_image(
yaml.safe_load(yaml_contents)["data_visualization_prompt"],
planning_model,
topic,
)
html_content = await generate_html(answer, toc_image_url)
await flyte.report.replace.aio(html_content, do_flush=True)
await flyte.report.flush.aio()
return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
To run the agent on a Flyte cluster, first store your Together AI and Tavily API keys as secrets:
```shell
flyte create secret TOGETHER_API_KEY <>
flyte create secret TAVILY_API_KEY <>
```
Then launch the workflow:
```shell
uv run agent.py
```
To run the agent locally instead, install pandoc and a TeX distribution (the same system dependencies the task image installs), export your API keys, and run the script directly:
```shell
brew install pandoc
brew install basictex # restart your terminal after install
export TOGETHER_API_KEY=<>
export TAVILY_API_KEY=<>
uv run agent.py
```
## Evaluating the agent
To evaluate the workflow end to end, we use Weave from Weights & Biases with an LLM-as-a-judge scorer. The `weave_evals.py` script loads benchmark questions from Hugging Face, runs `research_topic` on each, and scores the generated answers against the reference answers:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "weave==0.51.51",
# "datasets==3.6.0",
# "huggingface-hub==0.32.6",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# ]
# ///
import os
import weave
from agent import research_topic
from datasets import load_dataset
from huggingface_hub import login
from libs.utils.log import AgentLogger
from litellm import completion
import flyte
logging = AgentLogger()
weave.init(project_name="deep-researcher")
env = flyte.TaskEnvironment(name="deep-researcher-eval")
@weave.op
def llm_as_a_judge_scoring(answer: str, output: str, question: str) -> bool:
    prompt = f"""
    Given the following question and answer, evaluate the answer against the correct answer:

    <question>{question}</question>
    <agent_answer>{output}</agent_answer>
    <correct_answer>{answer}</correct_answer>

    Note that the agent answer might be a long text containing a lot of information or it might be a short answer.
    You should read the entire text and think if the agent answers the question somewhere
    in the text. You should try to be flexible with the answer but careful.
    For example, answering with names instead of name and surname is fine.
    The important thing is that the answer of the agent either contains the correct answer or is equal to
    the correct answer.
    If the agent answer is correct, return:
    The agent answer is correct because I can read that ....
    <answer>1</answer>
    Otherwise, return
    The agent answer is incorrect because there is ...
    <answer>0</answer>
    """
messages = [
{
"role": "system",
"content": "You are an helpful assistant that returns a number between 0 and 1.",
},
{"role": "user", "content": prompt},
]
answer = (
completion(
model="together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
messages=messages,
max_tokens=1000,
temperature=0.0,
)
.choices[0] # type: ignore
.message["content"] # type: ignore
)
    return bool(int(answer.split("<answer>")[1].split("</answer>")[0].strip()))
def authenticate_huggingface():
"""Authenticate with Hugging Face Hub using token from environment variable."""
token = os.getenv("HUGGINGFACE_TOKEN")
if not token:
raise ValueError(
"HUGGINGFACE_TOKEN environment variable not set. "
"Please set it with your token from https://huggingface.co/settings/tokens"
)
try:
login(token=token)
print("Successfully authenticated with Hugging Face Hub")
except Exception as e:
raise RuntimeError(f"Failed to authenticate with Hugging Face Hub: {e!s}")
@env.task
async def load_questions(
dataset_names: list[str] | None = None,
) -> list[dict[str, str]]:
"""
Load questions from the specified Hugging Face dataset configurations.
Args:
dataset_names: List of dataset configurations to load
Options:
"smolagents:simpleqa",
"hotpotqa",
"simpleqa",
"together-search-bench"
If None, all available configurations except hotpotqa will be loaded
Returns:
List of question-answer pairs
"""
if dataset_names is None:
dataset_names = ["smolagents:simpleqa"]
all_questions = []
# Authenticate with Hugging Face Hub (once and for all)
authenticate_huggingface()
for dataset_name in dataset_names:
print(f"Loading dataset: {dataset_name}")
try:
if dataset_name == "together-search-bench":
# Load Together-Search-Bench dataset
dataset_path = "togethercomputer/together-search-bench"
                ds = load_dataset(dataset_path)
                if "test" in ds:
                    split_data = ds["test"]
                else:
                    print(f"No 'test' split found in dataset at {dataset_path}")
                    continue

                for i in range(len(split_data)):
                    item = split_data[i]
                    question_data = {
                        "question": item["question"],
                        "answer": item["answer"],
                        "dataset": item.get("dataset", "together-search-bench"),
                    }
                    all_questions.append(question_data)

                print(f"Loaded {len(split_data)} questions from together-search-bench dataset")
                continue
            elif dataset_name == "hotpotqa":
                # Load HotpotQA dataset (using distractor version for validation)
                ds = load_dataset("hotpotqa/hotpot_qa", "distractor", trust_remote_code=True)
                split_name = "validation"
            elif dataset_name == "simpleqa":
                ds = load_dataset("basicv8vc/SimpleQA")
                split_name = "test"
            else:
                # Strip "smolagents:" prefix when loading the dataset
                actual_dataset = dataset_name.split(":")[-1]
                ds = load_dataset("smolagents/benchmark-v1", actual_dataset)
                split_name = "test"
        except Exception as e:
            print(f"Failed to load dataset {dataset_name}: {e!s}")
            continue  # Skip this dataset if it fails to load

        print(f"Dataset structure for {dataset_name}: {ds}")
        print(f"Available splits: {list(ds)}")
        split_data = ds[split_name]  # type: ignore

        for i in range(len(split_data)):
            item = split_data[i]
            if dataset_name == "hotpotqa":
                # Keep only "hard" questions (if any) just to reduce the number of questions
                if item["level"] != "hard":
                    continue
                question_data = {
                    "question": item["question"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            elif dataset_name == "simpleqa":
                # Handle SimpleQA dataset format
                question_data = {
                    "question": item["problem"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            else:
                question_data = {
                    "question": item["question"],
                    "answer": item["true_answer"],
                    "dataset": dataset_name,
                }
            all_questions.append(question_data)

    print(f"Loaded {len(all_questions)} questions in total")
    return all_questions


@weave.op
async def predict(question: str):
    return await research_topic(topic=str(question))


@env.task
async def main(datasets: list[str] = ["together-search-bench"], limit: int | None = 1):
    questions = await load_questions(datasets)
    if limit is not None:
        questions = questions[:limit]
        print(f"Limited to {len(questions)} question(s)")

    evaluation = weave.Evaluation(dataset=questions, scorers=[llm_as_a_judge_scoring])
    await evaluation.evaluate(predict)


if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(raw_data_path="data").run(main)
```

Set the required API keys, then run the script:

```bash
export HUGGINGFACE_TOKEN=<> # https://huggingface.co/settings/tokens
export WANDB_API_KEY=<> # https://wandb.ai/settings
uv run weave_evals.py
```
The script will run all tasks in the pipeline and log the evaluation results to Weights & Biases.
While you can also evaluate individual tasks, this script focuses on end-to-end evaluation of the complete deep research workflow.

=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/hpo ===
# Hyperparameter optimization
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/ml/optimizer.py).
Hyperparameter Optimization (HPO) is a critical step in the machine learning (ML) lifecycle. Hyperparameters are the knobs and dials of a model: values such as learning rates, tree depths, or dropout rates that significantly impact performance but cannot be learned during training. Instead, we must select them manually or optimize them through guided search.
Model developers often enjoy the flexibility of choosing from a wide variety of model types, whether gradient boosted machines (GBMs), generalized linear models (GLMs), deep learning architectures, or dozens of others. A common challenge across all these options is the need to systematically explore model performance across hyperparameter configurations tailored to the specific dataset and task.
Thankfully, this exploration can be automated. Frameworks like [Optuna](https://optuna.org/), [Hyperopt](https://hyperopt.github.io/hyperopt/), and [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) use advanced sampling algorithms to efficiently search the hyperparameter space and identify optimal configurations. HPO may be executed in two distinct ways:
- **Serial HPO** runs one trial at a time, which is easy to set up but can be painfully slow.
- **Parallel HPO** distributes trials across multiple processes. It typically follows a pattern with two parameters: **_N_**, the total number of trials to run, and **_C_**, the maximum number of trials that can run concurrently. Trials are executed asynchronously, and new ones are scheduled based on the results and status of completed or in-progress ones.
However, parallel HPO introduces a new complexity: the need for a centralized state that tracks:
- All past trials (successes and failures)
- All ongoing trials
This state is essential so that the optimization algorithm can make informed decisions about which hyperparameters to try next.
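The **N**/**C** pattern and the shared trial history can be sketched in plain `asyncio`. This is an illustrative sketch only: `objective`, `run_trial`, and `optimize` are hypothetical stand-ins (a random search over a toy function), not part of any framework.

```python
import asyncio
import random

async def objective(params: dict) -> float:
    """Stand-in for a real training run; maximized at x = 3."""
    await asyncio.sleep(0.01)
    return -(params["x"] - 3.0) ** 2

async def run_trial(semaphore: asyncio.Semaphore, study: list) -> None:
    # The semaphore caps how many trials run concurrently (the "C" parameter)
    async with semaphore:
        params = {"x": random.uniform(0.0, 6.0)}  # random-search stand-in
        score = await objective(params)
        # "study" is the centralized state: a real sampler would read it
        # to decide which hyperparameters to try next
        study.append({"params": params, "score": score})

async def optimize(n_trials: int, concurrency: int) -> dict:
    study: list = []  # shared trial history (successes and failures)
    semaphore = asyncio.Semaphore(concurrency)
    # Launch all N trials; the semaphore throttles them to C at a time
    await asyncio.gather(*(run_trial(semaphore, study) for _ in range(n_trials)))
    return max(study, key=lambda t: t["score"])

best = asyncio.run(optimize(n_trials=20, concurrency=5))
print(best["params"], best["score"])
```

A real optimizer replaces the random sampling with an informed algorithm, which is exactly why the centralized trial history matters.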
## A better way to run HPO
This is where Flyte shines.
- There's no need to manage a separate centralized database for state tracking, as every objective run is **cached**, **recorded**, and **recoverable** via Flyte's execution engine.
- The entire HPO process is observable in the UI with full lineage and metadata for each trial.
- Each objective is seeded for reproducibility, enabling deterministic trial results.
- If the main optimization task crashes or is terminated, **Flyte can resume from the last successful or failed trial, making the experiment highly fault-tolerant**.
- Trial functions can be strongly typed, enabling rich, flexible hyperparameter spaces while maintaining strict type safety across trials.
In this example, we combine Flyte with Optuna to optimize a `RandomForestClassifier` on the Iris dataset. Each trial runs in an isolated task, and the optimization process is orchestrated asynchronously, with Flyte handling the underlying scheduling, retries, and caching.
## Declare dependencies
We start by declaring a Python environment using Python 3.13 and specifying our runtime dependencies.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "optuna>=4.0.0,<5.0.0",
#     "flyte>=2.0.0b0",
#     "scikit-learn==1.7.0",
# ]
# ///
```
With the environment defined, we begin by importing standard library and third-party modules necessary for both the ML task and distributed execution.
```
import asyncio
import typing
from collections import Counter
from typing import Optional, Union
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
These standard library imports are essential for asynchronous execution (`asyncio`), type annotations (`typing`, `Optional`, `Union`), and aggregating trial state counts (`Counter`).
```
import optuna
from optuna import Trial
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
We use Optuna for hyperparameter optimization and several utilities from scikit-learn to prepare data (`load_iris`), define the model (`RandomForestClassifier`), evaluate it (`cross_val_score`), and shuffle the dataset for randomness (`shuffle`).
```
import flyte
import flyte.errors
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
Flyte is our orchestration framework. We use it to define tasks, manage resources, and recover from execution errors.
## Define the task environment
We define a Flyte task environment called `driver`, which encapsulates metadata, compute resources, the container image context needed for remote execution, and caching behavior.
```
driver = flyte.TaskEnvironment(
    name="driver",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="optimizer"),
    cache="auto",
)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
This environment specifies that the tasks will run with 1 CPU and 250Mi of memory, the image is built using the current script (`__file__`), and caching is enabled.
## Define the optimizer
Next, we define an `Optimizer` class that handles parallel execution of Optuna trials using async coroutines. This class abstracts the full optimization loop and supports concurrent trial execution with live logging.
```
class Optimizer:
    def __init__(
        self,
        objective: callable,
        n_trials: int,
        concurrency: int = 1,
        delay: float = 0.1,
        study: Optional[optuna.Study] = None,
        log_delay: float = 0.1,
    ):
        self.n_trials: int = n_trials
        self.concurrency: int = concurrency
        self.objective: typing.Callable = objective
        self.delay: float = delay
        self.log_delay = log_delay
        self.study = study if study else optuna.create_study()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
We pass the `objective` function, number of trials to run (`n_trials`), and maximum parallel trials (`concurrency`). The optional delay throttles execution between trials, while `log_delay` controls how often logging runs. If no existing Optuna Study is provided, a new one is created automatically.
```
    async def log(self):
        while True:
            await asyncio.sleep(self.log_delay)
            counter = Counter()
            for trial in self.study.trials:
                counter[trial.state.name.lower()] += 1
            counts = dict(counter, queued=self.n_trials - len(self))
            # print items in dictionary in a readable format
            formatted = [f"{name}: {count}" for name, count in counts.items()]
            print(f"{' '.join(formatted)}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
This method periodically prints the number of trials in each state (e.g., running, complete, fail). It keeps users informed of ongoing optimization progress and is invoked as a background task when logging is enabled.

_Logs are streamed live as the execution progresses._
```
    async def spawn(self, semaphore: asyncio.Semaphore):
        async with semaphore:
            trial: Trial = self.study.ask()
            try:
                print("Starting trial", trial.number)
                params = {
                    "n_estimators": trial.suggest_int("n_estimators", 10, 200),
                    "max_depth": trial.suggest_int("max_depth", 2, 20),
                    "min_samples_split": trial.suggest_float(
                        "min_samples_split", 0.1, 1.0
                    ),
                }
                output = await self.objective(params)
                self.study.tell(trial, output, state=optuna.trial.TrialState.COMPLETE)
            except flyte.errors.RuntimeUserError as e:
                print(f"Trial {trial.number} failed: {e}")
                self.study.tell(trial, state=optuna.trial.TrialState.FAIL)
            await asyncio.sleep(self.delay)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
Each call to `spawn` runs a single Optuna trial. The `semaphore` ensures that only a fixed number of concurrent trials are active at once, respecting the `concurrency` parameter. We first ask Optuna for a new trial and generate a parameter dictionary by querying the trial object for suggested hyperparameters. The trial is then evaluated by the objective function. If successful, we mark it as `COMPLETE`. If the trial fails due to a `RuntimeUserError` from Flyte, we log and record the failure in the Optuna study.
```
    async def __call__(self):
        # create semaphore to manage concurrency
        semaphore = asyncio.Semaphore(self.concurrency)
        # create list of async trials
        trials = [self.spawn(semaphore) for _ in range(self.n_trials)]
        logger: Optional[asyncio.Task] = None
        if self.log_delay:
            logger = asyncio.create_task(self.log())
        # await all trials to complete
        await asyncio.gather(*trials)
        if self.log_delay and logger:
            logger.cancel()
            try:
                await logger
            except asyncio.CancelledError:
                pass
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
The `__call__` method defines the overall async optimization routine. It creates the semaphore, spawns `n_trials` coroutines, and optionally starts the background logging task. All trials are awaited with `asyncio.gather`.
```
    def __len__(self) -> int:
        """Return the number of trials in history."""
        return len(self.study.trials)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
This method simply allows us to query the number of trials already associated with the study.
## Define the objective function
The objective task defines how we evaluate a particular set of hyperparameters. It's an async task, allowing for caching, tracking, and recoverability across executions.
```
@driver.task
async def objective(params: dict[str, Union[int, float]]) -> float:
    data = load_iris()
    X, y = shuffle(data.data, data.target, random_state=42)
    clf = RandomForestClassifier(
        n_estimators=params["n_estimators"],
        max_depth=params["max_depth"],
        min_samples_split=params["min_samples_split"],
        random_state=42,
        n_jobs=-1,
    )
    # Use cross-validation to evaluate performance
    score = cross_val_score(clf, X, y, cv=3, scoring="accuracy").mean()
    return score.item()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
We use the Iris dataset as a toy classification problem. The input params dictionary contains the trial's hyperparameters, which we unpack into a `RandomForestClassifier`. We shuffle the dataset for randomness, and compute a 3-fold cross-validation accuracy.
## Define the main optimization loop
The optimize task is the main driver of our optimization experiment. It creates the `Optimizer` instance and invokes it.
```
@driver.task
async def optimize(
    n_trials: int = 20,
    concurrency: int = 5,
    delay: float = 0.05,
    log_delay: float = 0.1,
) -> dict[str, Union[int, float]]:
    optimizer = Optimizer(
        objective=objective,
        n_trials=n_trials,
        concurrency=concurrency,
        delay=delay,
        log_delay=log_delay,
        study=optuna.create_study(
            direction="maximize", sampler=optuna.samplers.TPESampler(seed=42)
        ),
    )
    await optimizer()

    best = optimizer.study.best_trial
    print("Best trial")
    print("  Number :", best.number)
    print("  Params :", best.params)
    print("  Score  :", best.value)
    return best.params
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
We configure a `TPESampler` for Optuna and `seed` it for determinism. After running all trials, we extract the best-performing trial and print its parameters and score. Returning the best params allows downstream tasks or clients to use the tuned model.
## Run the experiment
Finally, we include an executable entry point to run this optimization using `flyte.run`.
```
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(optimize, 100, 10)
    print(run.url)
    run.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*
We load Flyte config from `config.yaml`, launch the optimize task with 100 trials and concurrency of 10, and print a link to view the execution in the Flyte UI.

_Each objective run is cached, recorded, and recoverable. With concurrency set to 10, only 10 trials execute in parallel at any given time._
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations ===
# Integrations
> [!NOTE]
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
Flyte 2 is designed to be extensible by default. While the core platform covers the most common orchestration needs, many production workloads require specialized infrastructure, external services or execution semantics that go beyond the core runtime.
Flyte 2 exposes these capabilities through integrations.
Under the hood, integrations are implemented using Flyte 2's plugin system, which provides a consistent way to extend the platform without modifying core execution logic.
An integration allows you to declaratively enable new capabilities such as distributed compute frameworks or third-party services without manually managing infrastructure. You specify what you need, and Flyte takes care of how it is provisioned, used and cleaned up.
This page covers:
- The types of integrations Flyte 2 supports today
- How integrations fit into Flyte 2's execution model
- How to use integrations in your tasks
- The integrations available out of the box
If you need functionality that doesn't exist yet, Flyte 2's plugin system is intentionally open-ended. You can build and register your own integrations using the same architecture described here.
## Integration categories
Flyte 2 integrations fall into the following categories:
1. **Distributed compute**: Provision transient compute clusters to run tasks across multiple nodes, with automatic lifecycle management.
2. **Agentic AI**: Support for various common aspects of agentic AI applications.
3. **Experiment tracking**: Integrate with experiment tracking platforms for logging metrics, parameters, and artifacts.
4. **Connectors**: Stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems.
5. **LLM Serving**: Deploy and serve large language models with an OpenAI-compatible API.
## Distributed compute
Distributed compute integrations allow tasks to run on dynamically provisioned clusters. These clusters are created just-in-time, scoped to the task execution and torn down automatically when the task completes.
This enables large-scale parallelism without requiring users to operate or maintain long-running infrastructure.
### Supported distributed compute integrations
| Plugin | Description | Common use cases |
| -------------------- | ------------------------------------------------ | ------------------------------------------------------ |
| [Ray](./ray/_index) | Provisions Ray clusters via KubeRay | Distributed Python, ML training, hyperparameter tuning |
| [Spark](./spark/_index) | Provisions Spark clusters via Spark Operator | Large-scale data processing, ETL pipelines |
| [Dask](./dask/_index) | Provisions Dask clusters via Dask Operator | Parallel Python workloads, dataframe operations |
| [PyTorch](./pytorch/_index) | Distributed PyTorch training with elastic launch | Single-node and multi-node training |
Each plugin encapsulates:
- Cluster provisioning
- Resource configuration
- Networking and service discovery
- Lifecycle management and teardown
From the task author's perspective, these details are abstracted away.
### How the plugin system works
At a high level, Flyte 2's distributed compute plugin architecture follows a simple and consistent pattern.
#### 1. Registration
Each plugin registers itself with Flyte 2's core plugin registry:
- **`TaskPluginRegistry`**: The central registry for all distributed compute plugins
- Each plugin declares:
- Its configuration schema
- How that configuration maps to execution behavior
This registration step makes the plugin discoverable by the runtime.
#### 2. Task environments and plugin configuration
Integrations are activated through a `TaskEnvironment`.
A `TaskEnvironment` bundles:
- A container image
- Execution settings
- A plugin configuration object enabled with `plugin_config`
The plugin configuration describes _what_ infrastructure or integration the task requires.
#### 3. Automatic provisioning and execution
When a task associated with a `TaskEnvironment` runs:
1. Flyte inspects the environment's plugin configuration
2. The plugin provisions the required infrastructure or integration
3. The task executes with access to that capability
4. Flyte cleans up all transient resources after completion
### Example: Using the Dask plugin
Below is a complete example showing how a task gains access to a Dask cluster simply by running inside an environment configured with the Dask plugin.
```python
from flyteplugins.dask import Dask, WorkerGroup

import flyte

# Define the Dask cluster configuration
dask_config = Dask(
    workers=WorkerGroup(number_of_workers=4)
)

# Create a task environment that enables Dask
env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
)

# Any task in this environment has access to the Dask cluster
@env.task
async def process_data(data: list) -> list:
    from distributed import Client

    client = Client()  # Automatically connects to the provisioned cluster
    futures = client.map(transform, data)
    return client.gather(futures)
```
When `process_data` executes, Flyte performs the following steps:
1. Provisions a Dask cluster with 4 workers
2. Executes the task with network access to the cluster
3. Tears down the cluster once the task completes
No cluster management logic appears in the task code. The task only expresses intent.
### Key design principle
All distributed compute integrations follow the same mental model:
- You declare the required capability via configuration
- You attach that configuration to a task environment
- Tasks decorated with that environment automatically gain access to the capability
This makes it easy to swap execution backends or introduce distributed compute incrementally without rewriting workflows.
## Agentic AI
Agentic AI integrations provide drop-in replacements for LLM provider SDKs. They let you use Flyte tasks as agent tools so that tool calls run with full Flyte observability, retries, and caching.
### Supported agentic AI integrations
| Plugin | Description | Common use cases |
| ----------------------------------- | ------------------------------------------------------------ | ------------------------------------ |
| [OpenAI](./openai/_index) | Drop-in replacement for OpenAI Agents SDK `function_tool` | Agentic workflows with OpenAI models |
| [Anthropic](./anthropic/_index) | Agent loop and `function_tool` for the Anthropic Claude SDK | Agentic workflows with Claude |
| [Gemini](./gemini/_index) | Agent loop and `function_tool` for the Google Gemini SDK | Agentic workflows with Gemini |
| [Code generation](./codegen/_index) | LLM-driven code generation with automatic testing in sandboxes | Data processing, ETL, analysis pipelines |
## Experiment tracking
Experiment tracking integrations let you log metrics, parameters, and artifacts to external tracking platforms during Flyte task execution.
### Supported experiment tracking integrations
| Plugin | Description | Common use cases |
| ------------------------------------ | ---------------------------- | --------------------------------------------- |
| [MLflow](./mlflow/_index) | MLflow experiment tracking | Experiment tracking, autologging, model registry |
| [Weights and Biases](./wandb/_index) | Weights & Biases integration | Experiment tracking and hyperparameter tuning |
## Connectors
Connectors are stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems. Each connector runs as its own Kubernetes deployment and is triggered when a Flyte task of the matching type is executed.
Although they normally run inside the control plane, you can also run connectors locally as long as the required secrets/credentials are present. This is useful because connectors are just Python services that can be spawned in-process.
Connectors are designed to scale horizontally and reduce load on the core Flyte backend because they execute _outside_ the core system. This decoupling makes connectors efficient, resilient, and easy to iterate on. You can even test them locally without modifying backend configuration, which reduces friction during development.
### Supported connectors
| Connector | Description | Common use cases |
| ---------------------------------- | ---------------------------------------------- | ---------------------------------------- |
| [Snowflake](./snowflake/_index) | Run SQL queries on Snowflake asynchronously | Data warehousing, ETL, analytics queries |
| [BigQuery](./bigquery/_index) | Run SQL queries on Google BigQuery | Data warehousing, ETL, analytics queries |
| [Databricks](./databricks/_index) | Run PySpark jobs on Databricks clusters | Large-scale data processing, Spark ETL |
### Creating a new connector
If none of the existing connectors meet your needs, you can build your own.
> [!NOTE]
> Connectors communicate via Protobuf, so in theory they can be implemented in any language.
> Today, only **Python** connectors are supported.
### Async connector interface
To implement a new async connector, extend `AsyncConnector` and implement the following methods, all of which must be idempotent:
| Method | Purpose |
| -------- | ----------------------------------------------------------- |
| `create` | Launch the external job (via REST, gRPC, SDK, or other API) |
| `get` | Fetch current job state (return job status or output) |
| `delete` | Delete / cancel the external job |
To test the connector locally, the connector task should inherit from
[AsyncConnectorExecutorMixin](https://github.com/flyteorg/flyte-sdk/blob/1d49299294cd5e15385fe8c48089b3454b7a4cd1/src/flyte/connectors/_connector.py#L206). This mixin simulates how the Flyte 2 system executes asynchronous connector tasks, making it easier to validate your connector implementation before deploying it.
### Example: Model training connector
The following example implements a connector that launches a model training job on an external training service.
```python
import typing
from dataclasses import dataclass

import httpx
from flyte.connectors import AsyncConnector, Resource, ResourceMeta
from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
from flyteidl2.core.tasks_pb2 import TaskTemplate
from google.protobuf import json_format


@dataclass
class ModelTrainJobMeta(ResourceMeta):
    job_id: str
    endpoint: str


class ModelTrainingConnector(AsyncConnector):
    """
    Example connector that launches an ML model training job on an external training service.

    POST   -> launch training job
    GET    -> poll training progress
    DELETE -> cancel training job
    """

    name = "Model Training Connector"
    task_type_name = "external_model_training"
    metadata_type = ModelTrainJobMeta

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: typing.Optional[typing.Dict[str, typing.Any]],
        **kwargs,
    ) -> ModelTrainJobMeta:
        """
        Submit training job via POST.
        Response returns the job_id we later use in get().
        """
        custom = json_format.MessageToDict(task_template.custom) if task_template.custom else None
        async with httpx.AsyncClient() as client:
            r = await client.post(
                custom["endpoint"],
                json={"dataset_uri": inputs["dataset_uri"], "epochs": inputs["epochs"]},
            )
            r.raise_for_status()
            return ModelTrainJobMeta(job_id=r.json()["job_id"], endpoint=custom["endpoint"])

    async def get(self, resource_meta: ModelTrainJobMeta, **kwargs) -> Resource:
        """
        Poll external API until training job finishes.
        Must be safe to call repeatedly.
        """
        async with httpx.AsyncClient() as client:
            r = await client.get(f"{resource_meta.endpoint}/{resource_meta.job_id}")
            data = r.json()
        if data["status"] == "finished":
            return Resource(
                phase=TaskExecution.SUCCEEDED,
                log_links=[
                    TaskLog(
                        name="training-dashboard",
                        uri=f"https://example-mltrain.com/train/{resource_meta.job_id}",
                    )
                ],
                outputs={"results": data["results"]},
            )
        return Resource(phase=TaskExecution.RUNNING)

    async def delete(self, resource_meta: ModelTrainJobMeta, **kwargs):
        """
        Optionally call DELETE on external API.
        Safe even if job already completed.
        """
        async with httpx.AsyncClient() as client:
            await client.delete(f"{resource_meta.endpoint}/{resource_meta.job_id}")
```
To use this connector, you should define a task whose `task_type` matches the connector.
```python
import flyte.io
from typing import Any, Dict, Optional

from flyte.extend import TaskTemplate
from flyte.connectors import AsyncConnectorExecutorMixin
from flyte.models import NativeInterface, SerializationContext


class ModelTrainTask(AsyncConnectorExecutorMixin, TaskTemplate):
    _TASK_TYPE = "external_model_training"

    def __init__(
        self,
        name: str,
        endpoint: str,
        **kwargs,
    ):
        super().__init__(
            name=name,
            interface=NativeInterface(
                inputs={"epochs": int, "dataset_uri": str},
                outputs={"results": flyte.io.File},
            ),
            task_type=self._TASK_TYPE,
            **kwargs,
        )
        self.endpoint = endpoint

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        return {"endpoint": self.endpoint}
```
Here is an example of how to use the `ModelTrainTask`:
```python
import flyte
from flyteplugins.model_training import ModelTrainTask

model_train_task = ModelTrainTask(
    name="model_train",
    endpoint="https://example-mltrain.com",
)
model_train_env = flyte.TaskEnvironment.from_task("model_train_env", model_train_task)

env = flyte.TaskEnvironment(
    name="hello_world",
    resources=flyte.Resources(memory="250Mi"),
    image=flyte.Image.from_debian_base(name="model_training").with_pip_packages(
        "flyteplugins-model-training", pre=True
    ),
    depends_on=[model_train_env],
)


@env.task
def data_prep() -> str:
    return "gs://my-bucket/dataset.csv"


@env.task
def train_model(epochs: int) -> flyte.io.File:
    dataset_uri = data_prep()
    return model_train_task(epochs=epochs, dataset_uri=dataset_uri)
```
### Build a custom connector image
Build a custom image when you're ready to deploy your connector to your cluster.
To build the Docker image for your connector, run the following script:
```python
import asyncio

from flyte import Image
from flyte.extend import ImageBuildEngine


async def build_flyte_connector_bigquery_image(registry: str, name: str, builder: str = "local"):
    """
    Build the SDK default connector image, optionally overriding
    the container registry and image name.

    Args:
        registry: e.g. "ghcr.io/my-org" or "123456789012.dkr.ecr.us-west-2.amazonaws.com".
        name: e.g. "my-connector".
        builder: e.g. "local" or "remote".
    """
    default_image = Image.from_debian_base(
        registry=registry, name=name
    ).with_pip_packages("flyteintegrations-bigquery", pre=True)
    await ImageBuildEngine.build(default_image, builder=builder)


if __name__ == "__main__":
    print("Building connector image...")
    asyncio.run(
        build_flyte_connector_bigquery_image(
            registry="", name="flyte-bigquery", builder="local"
        )
    )
```
## LLM Serving
LLM serving integrations let you deploy and serve large language models as Flyte apps with an OpenAI-compatible API. They handle model loading, GPU management, and autoscaling.
### Supported LLM serving integrations
| Plugin | Description | Common use cases |
| ----------------------------------------------------------------- | ----------------------------------------------------- | ------------------------------------ |
| **Build apps > SGLang app** | Deploy models with SGLang's high-throughput runtime | LLM inference, model serving |
| **Build apps > vLLM app** | Deploy models with vLLM's PagedAttention engine | LLM inference, model serving |
For full setup instructions including multi-GPU deployment, model prefetching, and autoscaling, see the **Build apps > SGLang app** and **Build apps > vLLM app** pages.
## Subpages
- **Anthropic**
- **BigQuery**
- **Dask**
- **Databricks**
- **Gemini**
- **OpenAI**
- **PyTorch**
- **Ray**
- **Snowflake**
- **Spark**
- **Weights & Biases**
- **Code generation**
- **MLflow**
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/anthropic ===
# Anthropic
The Anthropic plugin lets you build agentic workflows with [Claude](https://www.anthropic.com/) on Flyte. It provides a `function_tool` decorator that wraps Flyte tasks as tools that Claude can call, and a `run_agent` function that drives the agent conversation loop.
When Claude calls a tool, the call executes as a Flyte task with full observability, retries, and caching.
## Installation
```bash
pip install flyteplugins-anthropic
```
Requires `anthropic >= 0.40.0`.
## Quick start
```python
import flyte
from flyteplugins.anthropic import function_tool, run_agent

env = flyte.TaskEnvironment(
    name="claude-agent",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="anthropic_agent"),
    secrets=flyte.Secret("anthropic_api_key", as_env_var="ANTHROPIC_API_KEY"),
)


@function_tool
@env.task
async def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"The weather in {city} is sunny, 72F"


@env.task
async def main(prompt: str) -> str:
    tools = [get_weather]
    return await run_agent(prompt=prompt, tools=tools)
```
## API
### `function_tool`
Converts a Flyte task, `@flyte.trace`-decorated function, or plain callable into a tool that Claude can invoke.
```python
@function_tool
@env.task
async def my_tool(param: str) -> str:
    """Tool description sent to Claude."""
    ...
```
Can also be called with optional overrides:
```python
@function_tool(name="custom_name", description="Custom description")
@env.task
async def my_tool(param: str) -> str:
    ...
```
Parameters:
| Parameter | Type | Description |
|-----------|------|-------------|
| `func` | callable | The function to wrap |
| `name` | `str` | Override the tool name (defaults to the function name) |
| `description` | `str` | Override the tool description (defaults to the docstring) |
> [!NOTE]
> The docstring on each `@function_tool` task is sent to Claude as the tool description. Write clear, concise docstrings.
### `Agent`
A dataclass for bundling agent configuration:
```python
from flyteplugins.anthropic import Agent

agent = Agent(
    name="my-agent",
    instructions="You are a helpful assistant.",
    model="claude-sonnet-4-20250514",
    tools=[get_weather],
    max_tokens=4096,
    max_iterations=10,
)
```
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `name` | `str` | `"assistant"` | Agent name |
| `instructions` | `str` | `"You are a helpful assistant."` | System prompt |
| `model` | `str` | `"claude-sonnet-4-20250514"` | Claude model ID |
| `tools` | `list[FunctionTool]` | `[]` | Tools available to the agent |
| `max_tokens` | `int` | `4096` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum tool-call loop iterations |
### `run_agent`
Runs a Claude conversation loop, dispatching tool calls to Flyte tasks until Claude returns a final response.
```python
result = await run_agent(
    prompt="What's the weather in Tokyo?",
    tools=[get_weather],
    model="claude-sonnet-4-20250514",
)
```
You can also pass an `Agent` object:
```python
result = await run_agent(prompt="What's the weather?", agent=agent)
```
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | `str` | required | User message |
| `tools` | `list[FunctionTool]` | `None` | Tools available to the agent |
| `agent` | `Agent` | `None` | Agent config (overrides individual params) |
| `model` | `str` | `"claude-sonnet-4-20250514"` | Claude model ID |
| `system` | `str` | `None` | System prompt |
| `max_tokens` | `int` | `4096` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum iterations (prevents infinite loops) |
| `api_key` | `str` | `None` | API key (falls back to `ANTHROPIC_API_KEY` env var) |
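Under the hood, `run_agent` is a standard tool-use loop: call the model, execute any requested tools, feed results back, and stop when the model returns plain text. The sketch below illustrates that loop with a stub in place of the real Anthropic client; the message shapes, tool registry, and stub client are simplifications for illustration, not the plugin's actual implementation.

```python
import asyncio

async def tool_loop(prompt, tools, client, max_iterations=10):
    """Simplified sketch of a tool-use loop: call the model, dispatch any
    requested tool call, feed the result back, stop on a final text reply."""
    registry = {t.__name__: t for t in tools}
    messages = [{"role": "user", "content": prompt}]
    for _ in range(max_iterations):
        reply = await client(messages)
        if reply["type"] == "text":            # final answer: exit the loop
            return reply["text"]
        tool = registry[reply["name"]]          # model asked for a tool call
        result = await tool(**reply["input"])
        messages.append({"role": "tool", "content": result})
    raise RuntimeError("max_iterations reached without a final reply")

# Stub client: the first turn requests a tool, the second returns text.
async def stub_client(messages):
    if not any(m["role"] == "tool" for m in messages):
        return {"type": "tool_use", "name": "get_weather", "input": {"city": "Tokyo"}}
    return {"type": "text", "text": messages[-1]["content"]}

async def get_weather(city: str) -> str:
    return f"The weather in {city} is sunny, 72F"

print(asyncio.run(tool_loop("What's the weather in Tokyo?", [get_weather], stub_client)))
# The weather in Tokyo is sunny, 72F
```

The `max_iterations` guard mirrors the parameter in the table above: it bounds the loop so a model that keeps requesting tools cannot spin forever.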
## Secrets
Store your Anthropic API key as a Flyte secret and expose it as an environment variable:
```python
secrets=flyte.Secret("anthropic_api_key", as_env_var="ANTHROPIC_API_KEY")
```
## API reference
See the [Anthropic API reference](../../api-reference/integrations/anthropic/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/bigquery ===
# BigQuery
The BigQuery connector lets you run SQL queries against [Google BigQuery](https://cloud.google.com/bigquery) directly from Flyte tasks. Queries are submitted asynchronously via the BigQuery Jobs API and polled for completion, so they don't block a worker while waiting for results.
The connector supports:
- Parameterized SQL queries with typed inputs
- Google Cloud service account authentication
- Query results returned as DataFrames
- Query cancellation on task abort
## Installation
```bash
pip install flyteplugins-bigquery
```
This installs the Google Cloud BigQuery client libraries.
## Quick start
Here's a minimal example that runs a SQL query on BigQuery:
```python
from flyte.io import DataFrame
from flyteplugins.bigquery import BigQueryConfig, BigQueryTask
config = BigQueryConfig(
ProjectID="my-gcp-project",
Location="US",
)
count_users = BigQueryTask(
name="count_users",
query_template="SELECT COUNT(*) FROM dataset.users",
plugin_config=config,
output_dataframe_type=DataFrame,
)
```
This defines a task called `count_users` that runs the query on the configured BigQuery instance. When executed, the connector:
1. Connects to BigQuery using the provided configuration
2. Submits the query asynchronously via the Jobs API
3. Polls until the query completes or fails
4. Returns the results as a `DataFrame` (if `output_dataframe_type` is set)
To run the task, create a `TaskEnvironment` from it and execute it locally or remotely:
```python
import flyte
bigquery_env = flyte.TaskEnvironment.from_task("bigquery_env", count_users)
if __name__ == "__main__":
flyte.init_from_config()
# Run locally (connector runs in-process, requires credentials locally)
run = flyte.with_runcontext(mode="local").run(count_users)
# Run remotely (connector runs on the control plane)
run = flyte.with_runcontext(mode="remote").run(count_users)
print(run.url)
```
> [!NOTE]
> The `TaskEnvironment` created by `from_task` does not need an image or pip packages. BigQuery tasks are connector tasks, which means the query executes on the connector service, not in your task container. In `local` mode, the connector runs in-process and requires `flyteplugins-bigquery` and credentials to be available on your machine.
## Configuration
### `BigQueryConfig` parameters
| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `ProjectID` | `str` | Yes | GCP project ID |
| `Location` | `str` | No | BigQuery region (e.g., `"US"`, `"EU"`) |
| `QueryJobConfig` | `bigquery.QueryJobConfig` | No | Native BigQuery [QueryJobConfig](https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJobConfig) object for advanced settings |
### `BigQueryTask` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `name` | `str` | Unique task name |
| `query_template` | `str` | SQL query (whitespace is normalized before execution) |
| `plugin_config` | `BigQueryConfig` | Connection configuration |
| `inputs` | `Dict[str, Type]` | Named typed inputs bound as query parameters |
| `output_dataframe_type` | `Type[DataFrame]` | If set, query results are returned as a `DataFrame` |
| `google_application_credentials` | `str` | Name of the Flyte secret containing the GCP service account JSON key |
## Authentication
Pass the name of a Flyte secret containing your GCP service account JSON key:
```python
query = BigQueryTask(
name="secure_query",
query_template="SELECT * FROM dataset.sensitive_data",
plugin_config=config,
google_application_credentials="my-gcp-sa-key",
)
```
## Query templating
Use the `inputs` parameter to define typed inputs for your query. Input values are bound as BigQuery `ScalarQueryParameter` values.
### Supported input types
| Python type | BigQuery type |
|-------------|---------------|
| `int` | `INT64` |
| `float` | `FLOAT64` |
| `str` | `STRING` |
| `bool` | `BOOL` |
| `bytes` | `BYTES` |
| `datetime` | `DATETIME` |
| `list` | `ARRAY` |
### Parameterized query example
```python
from flyte.io import DataFrame
events_by_region = BigQueryTask(
name="events_by_region",
query_template="SELECT * FROM dataset.events WHERE region = @region AND score > @min_score",
plugin_config=config,
inputs={"region": str, "min_score": float},
output_dataframe_type=DataFrame,
)
```
> [!NOTE]
> The query template is normalized before execution: newlines and tabs are replaced with spaces and consecutive whitespace is collapsed. You can format your queries across multiple lines for readability without affecting execution.
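A rough sketch of that normalization (an approximation of the described behavior, not the connector's exact code):

```python
import re

def normalize_query(template: str) -> str:
    # Replace newlines/tabs with spaces and collapse runs of whitespace.
    return re.sub(r"\s+", " ", template).strip()

query = """
    SELECT customer_id, SUM(amount) AS total_spend
    FROM dataset.orders
    GROUP BY customer_id
"""
print(normalize_query(query))
# SELECT customer_id, SUM(amount) AS total_spend FROM dataset.orders GROUP BY customer_id
```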
## Retrieving query results
Set `output_dataframe_type` to capture results as a DataFrame:
```python
from flyte.io import DataFrame
top_customers = BigQueryTask(
name="top_customers",
query_template="""
SELECT customer_id, SUM(amount) AS total_spend
FROM dataset.orders
GROUP BY customer_id
ORDER BY total_spend DESC
LIMIT 100
""",
plugin_config=config,
output_dataframe_type=DataFrame,
)
```
If you don't need query results (for example, DDL statements or INSERT queries), omit `output_dataframe_type`.
## API reference
See the [BigQuery API reference](../../api-reference/integrations/bigquery/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/dask ===
# Dask
The Dask plugin lets you run [Dask](https://www.dask.org/) jobs natively on Kubernetes. Flyte provisions a transient Dask cluster for each task execution using the [Dask Kubernetes Operator](https://kubernetes.dask.org/en/latest/operator.html) and tears it down on completion.
## When to use this plugin
- Parallel Python workloads that outgrow a single machine
- Distributed DataFrame operations on large datasets
- Workloads that use Dask's task scheduler for arbitrary computation graphs
- Jobs that need to scale NumPy, pandas, or scikit-learn workflows across multiple nodes
## Installation
```bash
pip install flyteplugins-dask
```
Your task image must also include the plugin, so that Dask's distributed scheduler is available on the scheduler and worker pods:
```python
image = flyte.Image.from_debian_base(name="dask").with_pip_packages("flyteplugins-dask")
```
## Configuration
Create a `Dask` configuration and pass it as `plugin_config` to a `TaskEnvironment`:
```python
from flyteplugins.dask import Dask, Scheduler, WorkerGroup
dask_config = Dask(
scheduler=Scheduler(),
workers=WorkerGroup(number_of_workers=4),
)
dask_env = flyte.TaskEnvironment(
name="dask_env",
plugin_config=dask_config,
image=image,
)
```
### `Dask` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `scheduler` | `Scheduler` | Scheduler pod configuration (defaults to `Scheduler()`) |
| `workers` | `WorkerGroup` | Worker group configuration (defaults to `WorkerGroup()`) |
### `Scheduler` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `image` | `str` | Custom scheduler image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests for the scheduler pod |
### `WorkerGroup` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `number_of_workers` | `int` | Number of worker pods (default: `1`) |
| `image` | `str` | Custom worker image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests per worker pod |
> [!NOTE]
> The scheduler and all workers should use the same Python environment to avoid serialization issues.
### Accessing the Dask client
Inside a Dask task, create a `distributed.Client()` with no arguments. It automatically connects to the provisioned cluster:
```python
from distributed import Client
@dask_env.task
async def my_dask_task(n: int) -> list:
client = Client()
futures = client.map(lambda x: x + 1, range(n))
return client.gather(futures)
```
## Example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-dask",
# "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///
import asyncio
import typing
from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup
import flyte.remote
import flyte.storage
from flyte import Resources
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyteplugins-dask")
dask_config = Dask(
scheduler=Scheduler(),
workers=WorkerGroup(number_of_workers=4),
)
task_env = flyte.TaskEnvironment(
name="hello_dask", resources=Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
dask_env = flyte.TaskEnvironment(
name="dask_env",
plugin_config=dask_config,
image=image,
resources=Resources(cpu="1", memory="1Gi"),
depends_on=[task_env],
)
@task_env.task()
async def hello_dask():
await asyncio.sleep(5)
print("Hello from the Dask task!")
@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
print("running dask task")
t = asyncio.create_task(hello_dask())
client = Client()
futures = client.map(lambda x: x + 1, range(n))
res = client.gather(futures)
await t
return res
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(hello_dask_nested)
print(r.name)
print(r.url)
r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/dask/dask_example.py*
## API reference
See the [Dask API reference](../../api-reference/integrations/dask/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/databricks ===
# Databricks
The Databricks plugin lets you run PySpark jobs on [Databricks](https://www.databricks.com/) clusters directly from Flyte tasks. You write normal PySpark code in a Flyte task, and the plugin submits it to Databricks via the [Jobs API 2.1](https://docs.databricks.com/api/workspace/jobs/submit). The connector handles job submission, polling, and cancellation.
The plugin supports:
- Running PySpark tasks on new or existing Databricks clusters
- Full Spark configuration (driver/executor memory, cores, instances)
- Databricks cluster auto-scaling
- API token-based authentication
## Installation
```bash
pip install flyteplugins-databricks
```
This also installs `flyteplugins-spark` as a dependency, since the Databricks plugin extends the Spark plugin.
## Quick start
Create a `Databricks` configuration and pass it as `plugin_config` to a `TaskEnvironment`:
```python
from flyteplugins.databricks import Databricks
import flyte
image = (
flyte.Image.from_base("databricksruntime/standard:16.4-LTS")
.clone(name="spark", registry="ghcr.io/flyteorg", extendable=True)
.with_env_vars({"UV_PYTHON": "/databricks/python3/bin/python"})
.with_pip_packages("flyteplugins-databricks", pre=True)
)
databricks_conf = Databricks(
spark_conf={
"spark.driver.memory": "2000M",
"spark.executor.memory": "1000M",
"spark.executor.cores": "1",
"spark.executor.instances": "2",
"spark.driver.cores": "1",
},
executor_path="/databricks/python3/bin/python",
databricks_conf={
"run_name": "flyte databricks plugin",
"new_cluster": {
"spark_version": "13.3.x-scala2.12",
"node_type_id": "m6i.large",
"autoscale": {"min_workers": 1, "max_workers": 2},
},
"timeout_seconds": 3600,
"max_retries": 1,
},
databricks_instance="myaccount.cloud.databricks.com",
databricks_token="DATABRICKS_TOKEN",
)
databricks_env = flyte.TaskEnvironment(
name="databricks_env",
resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
plugin_config=databricks_conf,
image=image,
)
```
Then use the environment to decorate your task:
```python
@databricks_env.task
async def hello_databricks() -> float:
spark = flyte.ctx().data["spark_session"]
# Use spark as a normal SparkSession
count = spark.sparkContext.parallelize(range(100)).count()
return float(count)
```
## Configuration
The `Databricks` config extends the [Spark](../spark/_index) config with Databricks-specific fields.
### Spark fields (inherited)
| Parameter | Type | Description |
|-----------|------|-------------|
| `spark_conf` | `Dict[str, str]` | Spark configuration key-value pairs |
| `hadoop_conf` | `Dict[str, str]` | Hadoop configuration key-value pairs |
| `executor_path` | `str` | Path to the Python binary on the Databricks cluster (e.g., `/databricks/python3/bin/python`) |
| `applications_path` | `str` | Path to the main application file |
### Databricks-specific fields
| Parameter | Type | Description |
|-----------|------|-------------|
| `databricks_conf` | `Dict[str, Union[str, dict]]` | Databricks [run-submit](https://docs.databricks.com/api/workspace/jobs/submit) job configuration. Must contain either `existing_cluster_id` or `new_cluster` |
| `databricks_instance` | `str` | Your workspace domain (e.g., `myaccount.cloud.databricks.com`). Can also be set via the `FLYTE_DATABRICKS_INSTANCE` env var on the connector |
| `databricks_token` | `str` | Name of the Flyte secret containing the Databricks API token |
### `databricks_conf` structure
The `databricks_conf` dict maps to the Databricks run-submit API payload. Key fields:
| Field | Description |
|-------|-------------|
| `new_cluster` | Cluster spec with `spark_version`, `node_type_id`, `autoscale`, etc. |
| `existing_cluster_id` | ID of an existing cluster to use instead of creating a new one |
| `run_name` | Display name in the Databricks UI |
| `timeout_seconds` | Maximum job duration |
| `max_retries` | Number of retries before marking the job as failed |
The connector automatically injects the Docker image, Spark configuration, and environment variables from the task container into the cluster spec.
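The quick start above creates a new cluster via `new_cluster`; the table also allows reusing an existing cluster. A minimal sketch of that variant (the cluster ID, run name, and timeouts are illustrative placeholders, not working values):

```python
# Minimal run-submit payload reusing an existing Databricks cluster.
# The cluster ID and run name below are illustrative placeholders.
databricks_conf = {
    "run_name": "flyte databricks plugin",
    "existing_cluster_id": "1234-567890-abcde123",
    "timeout_seconds": 1800,
    "max_retries": 1,
}
```

Remember that `databricks_conf` must contain either `existing_cluster_id` or `new_cluster`, but not both.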
## Authentication
Store your Databricks API token as a Flyte secret. The `databricks_token` parameter specifies the secret name:
```python
databricks_conf = Databricks(
# ...
databricks_token="DATABRICKS_TOKEN",
)
```
## Accessing the Spark session
Inside a Databricks task, the `SparkSession` is available through the task context, just like the [Spark plugin](../spark/_index):
```python
@databricks_env.task
async def my_databricks_task() -> float:
spark = flyte.ctx().data["spark_session"]
df = spark.read.parquet("s3://my-bucket/data.parquet")
return float(df.count())
```
## API reference
See the [Databricks API reference](../../api-reference/integrations/databricks/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/gemini ===
# Gemini
The Gemini plugin lets you build agentic workflows with [Gemini](https://ai.google.dev/) on Flyte. It provides a `function_tool` decorator that wraps Flyte tasks as tools that Gemini can call, and a `run_agent` function that drives the agent conversation loop.
When Gemini calls a tool, the call executes as a Flyte task with full observability, retries, and caching. Gemini's native parallel function calling is supported: multiple tool calls in a single turn are all dispatched and their results bundled into one response.
## Installation
```bash
pip install flyteplugins-gemini
```
Requires `google-genai >= 1.0.0`.
## Quick start
```python
import flyte
from flyteplugins.gemini import function_tool, run_agent
env = flyte.TaskEnvironment(
name="gemini-agent",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="gemini_agent"),
secrets=flyte.Secret("google_api_key", as_env_var="GOOGLE_API_KEY"),
)
@function_tool
@env.task
async def get_weather(city: str) -> str:
"""Get the current weather for a city."""
return f"The weather in {city} is sunny, 72F"
@env.task
async def main(prompt: str) -> str:
tools = [get_weather]
return await run_agent(prompt=prompt, tools=tools)
```
## API
### `function_tool`
Converts a Flyte task, `@flyte.trace`-decorated function, or plain callable into a tool that Gemini can invoke.
```python
@function_tool
@env.task
async def my_tool(param: str) -> str:
"""Tool description sent to Gemini."""
...
```
Can also be called with optional overrides:
```python
@function_tool(name="custom_name", description="Custom description")
@env.task
async def my_tool(param: str) -> str:
...
```
Parameters:
| Parameter | Type | Description |
|-----------|------|-------------|
| `func` | callable | The function to wrap |
| `name` | `str` | Override the tool name (defaults to the function name) |
| `description` | `str` | Override the tool description (defaults to the docstring) |
> [!NOTE]
> The docstring on each `@function_tool` task is sent to Gemini as the tool description. Write clear, concise docstrings.
### `Agent`
A dataclass for bundling agent configuration:
```python
from flyteplugins.gemini import Agent
agent = Agent(
name="my-agent",
instructions="You are a helpful assistant.",
model="gemini-2.5-flash",
tools=[get_weather],
max_output_tokens=8192,
max_iterations=10,
)
```
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `name` | `str` | `"assistant"` | Agent name |
| `instructions` | `str` | `"You are a helpful assistant."` | System prompt |
| `model` | `str` | `"gemini-2.5-flash"` | Gemini model ID |
| `tools` | `list[FunctionTool]` | `[]` | Tools available to the agent |
| `max_output_tokens` | `int` | `8192` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum tool-call loop iterations |
### `run_agent`
Runs a Gemini conversation loop, dispatching tool calls to Flyte tasks until Gemini returns a final response.
```python
result = await run_agent(
prompt="What's the weather in Tokyo?",
tools=[get_weather],
model="gemini-2.5-flash",
)
```
You can also pass an `Agent` object:
```python
result = await run_agent(prompt="What's the weather?", agent=agent)
```
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | `str` | required | User message |
| `tools` | `list[FunctionTool]` | `None` | Tools available to the agent |
| `agent` | `Agent` | `None` | Agent config (overrides individual params) |
| `model` | `str` | `"gemini-2.5-flash"` | Gemini model ID |
| `system` | `str` | `None` | System prompt |
| `max_output_tokens` | `int` | `8192` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum iterations (prevents infinite loops) |
| `api_key` | `str` | `None` | API key (falls back to `GOOGLE_API_KEY` env var) |
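As noted above, Gemini's parallel function calling means a single model turn can request several tool calls, which are all dispatched and their results bundled into one response. A minimal sketch of that dispatch step (the call and registry shapes are illustrative, not the plugin's internals):

```python
import asyncio

async def dispatch_parallel(calls, registry):
    """Dispatch several tool calls from one model turn concurrently and
    bundle their results into a single ordered list of tool responses."""
    async def run_one(call):
        result = await registry[call["name"]](**call["args"])
        return {"name": call["name"], "result": result}
    # gather preserves the order in which the calls were requested
    return await asyncio.gather(*(run_one(c) for c in calls))

async def get_weather(city: str) -> str:
    return f"Sunny in {city}"

calls = [
    {"name": "get_weather", "args": {"city": "Tokyo"}},
    {"name": "get_weather", "args": {"city": "Paris"}},
]
results = asyncio.run(dispatch_parallel(calls, {"get_weather": get_weather}))
print([r["result"] for r in results])
# ['Sunny in Tokyo', 'Sunny in Paris']
```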
## Secrets
Store your Google API key as a Flyte secret and expose it as an environment variable:
```python
secrets=flyte.Secret("google_api_key", as_env_var="GOOGLE_API_KEY")
```
## API reference
See the [Gemini API reference](../../api-reference/integrations/gemini/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/openai ===
# OpenAI
The OpenAI plugin provides a drop-in replacement for the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) `function_tool` decorator. It lets you use Flyte tasks as tools in agentic workflows so that tool calls run as tracked, reproducible Flyte task executions.
## When to use this plugin
- Building agentic workflows with the OpenAI Agents SDK on Flyte
- You want tool calls to run as Flyte tasks with full observability, retries, and caching
- You want to combine LLM agents with existing Flyte pipelines
## Installation
```bash
pip install flyteplugins-openai
```
Requires `openai-agents >= 0.2.4`.
## Usage
The plugin provides a single decorator, `function_tool`, that wraps Flyte tasks as OpenAI agent tools.
### `function_tool`
When applied to a Flyte task (a function decorated with `@env.task`), `function_tool` makes that task available as an OpenAI `FunctionTool`. The agent can call it like any other tool, and the call executes as a Flyte task.
When applied to a regular function or a `@flyte.trace`-decorated function, it delegates directly to the OpenAI Agents SDK's built-in `function_tool`.
### Basic pattern
1. Define a `TaskEnvironment` with your image and secrets
2. Decorate your task functions with `@function_tool` and `@env.task`
3. Pass the tools to an `Agent`
4. Run the agent from another Flyte task
```python
from agents import Agent, Runner
from flyteplugins.openai.agents import function_tool
env = flyte.TaskEnvironment(
name="openai_agents",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
@function_tool
@env.task
async def get_weather(city: str) -> Weather:
"""Get the weather for a given city."""
return Weather(city=city, temperature_range="14-20C", conditions="Sunny")
agent = Agent(
name="Weather Agent",
instructions="You are a helpful agent.",
tools=[get_weather],
)
@env.task
async def main() -> str:
result = await Runner.run(agent, input="What's the weather in Tokyo?")
return result.final_output
```
> [!NOTE]
> The docstring on each `@function_tool` task is sent to the LLM as the tool description. Write clear, concise docstrings that describe what the tool does and what its parameters mean.
### Secrets
Store your OpenAI API key as a Flyte secret and expose it as an environment variable:
```python
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY")
```
## Example
````python
"""OpenAI Agents with Flyte, basic tool example.

Usage:
Create secret:
```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""

# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "flyteplugins-openai>=2.0.0b7",
#     "openai-agents>=0.2.4",
#     "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///

from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)

class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str

@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")

agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
````
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*
## API reference
See the [OpenAI API reference](../../api-reference/integrations/openai/_index) for full details.
## Subpages
- **OpenAI > Agent tools**
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/openai/agent_tools ===
# Agent tools
In this example, we will use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
This example is based on the [basic tools example](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tools.py) from the `openai-agents-python` repo.
First, create an OpenAI API key, which you can get from the [OpenAI website](https://platform.openai.com/account/api-keys).
Then, create a secret on your Flyte cluster with:
```
flyte create secret openai_api_key
```
Then, we'll use an inline `uv` script header to specify our dependencies:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "flyteplugins-openai>=2.0.0b7",
#     "openai-agents>=0.2.4",
#     "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*
Next, we'll import the libraries and create a `TaskEnvironment`, which we need to run the example:
```python
from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
```
## Define the tools
We'll define a tool that can get weather information for a
given city. In this case, we'll use a toy function that returns a hard-coded `Weather` object.
```python
class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str

@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
```
In this code snippet, the `@function_tool` decorator is imported from `flyteplugins.openai.agents`, a drop-in replacement for the `@function_tool` decorator from the `openai-agents` library.
## Define the agent
Then, we'll define the agent, which calls the tool:
```python
agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output
```
## Run the agent
Finally, we'll run the agent. Create a `config.yaml` file, which `flyte.init_from_config()` will use to connect to
the Flyte cluster:
```bash
flyte create config \
--output ~/.flyte/config.yaml \
--endpoint demo.hosted.unionai.cloud/ \
--project flytesnacks \
--domain development \
--builder remote
```
With the configuration in place, run the script with `uv run agents_tools.py`. The `main` guard initializes Flyte, submits the run, and waits for it to complete:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*
## Conclusion
In this example, we've seen how to use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
The full code is available [here](https://github.com/unionai/unionai-examples/tree/main/v2/integrations/flyte-plugins/openai/openai).
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/pytorch ===
# PyTorch
The PyTorch plugin lets you run distributed [PyTorch](https://pytorch.org/) training jobs natively on Kubernetes. It uses the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator) to manage multi-node training with PyTorch's elastic launch (`torchrun`).
## When to use this plugin
- Single-node or multi-node distributed training with `DistributedDataParallel` (DDP)
- Elastic training that can scale up and down during execution
- Any workload that uses `torch.distributed` for data-parallel or model-parallel training
## Installation
```bash
pip install flyteplugins-pytorch
```
## Configuration
Create an `Elastic` configuration and pass it as `plugin_config` to a `TaskEnvironment`:
```python
from flyteplugins.pytorch import Elastic

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nnodes=2,
        nproc_per_node=1,
    ),
    image=image,
)
```
### `Elastic` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `nnodes` | `int` or `str` | **Required.** Number of nodes. Use an int for a fixed count or a range string (e.g., `"2:4"`) for elastic training |
| `nproc_per_node` | `int` | **Required.** Number of processes (workers) per node |
| `rdzv_backend` | `str` | Rendezvous backend: `"c10d"` (default), `"etcd"`, or `"etcd-v2"` |
| `max_restarts` | `int` | Maximum worker group restarts (default: `3`) |
| `monitor_interval` | `int` | Agent health check interval in seconds (default: `3`) |
| `run_policy` | `RunPolicy` | Job run policy (cleanup, TTL, deadlines, retries) |
### `RunPolicy` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `clean_pod_policy` | `str` | Pod cleanup policy: `"None"`, `"All"`, or `"Running"` |
| `ttl_seconds_after_finished` | `int` | Seconds to keep pods after job completion |
| `active_deadline_seconds` | `int` | Maximum time the job can run (seconds) |
| `backoff_limit` | `int` | Number of retries before marking the job as failed |
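Putting the two tables together, a sketch of an elastic configuration with a run policy might look like the following (values are illustrative, and it is assumed that `RunPolicy` is importable alongside `Elastic`):

```python
from flyteplugins.pytorch import Elastic, RunPolicy

# Illustrative values; tune for your cluster
elastic_config = Elastic(
    nnodes="2:4",              # elastic range: minimum 2, maximum 4 nodes
    nproc_per_node=2,
    rdzv_backend="c10d",
    max_restarts=3,
    run_policy=RunPolicy(
        clean_pod_policy="Running",       # clean up still-running pods on completion
        ttl_seconds_after_finished=600,   # keep finished pods for 10 minutes
        active_deadline_seconds=3600,     # fail the job after one hour
        backoff_limit=2,
    ),
)
```

The `"2:4"` range string enables elastic training: the job can start with as few as two nodes and scale up to four.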
### NCCL tuning parameters
The plugin includes built-in NCCL timeout tuning to reduce failure-detection latency (PyTorch defaults to 1800 seconds):
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nccl_heartbeat_timeout_sec` | `int` | `300` | NCCL heartbeat timeout (seconds) |
| `nccl_async_error_handling` | `bool` | `False` | Enable async NCCL error handling |
| `nccl_collective_timeout_sec` | `int` | `None` | Timeout for NCCL collective operations |
| `nccl_enable_monitoring` | `bool` | `True` | Enable NCCL monitoring |
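These parameters are set on the same `Elastic` configuration. A hedged sketch with illustrative values:

```python
from flyteplugins.pytorch import Elastic

elastic_config = Elastic(
    nnodes=2,
    nproc_per_node=1,
    # Fail fast instead of waiting out PyTorch's 1800-second default
    nccl_heartbeat_timeout_sec=120,
    nccl_async_error_handling=True,
    nccl_collective_timeout_sec=300,
    nccl_enable_monitoring=True,
)
```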
### Writing a distributed training task
Tasks using this plugin do not need to be `async`. Initialize the process group and use `DistributedDataParallel` as you normally would with `torchrun`:
```python
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP


@torch_env.task
def train(epochs: int) -> float:
    torch.distributed.init_process_group("gloo")
    model = DDP(MyModel())
    # ... training loop ...
    return final_loss
```
> [!NOTE]
> When `nnodes=1`, the task runs as a regular Python task (no Kubernetes training job is created). Set `nnodes >= 2` for multi-node distributed training.
## Example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pytorch",
#    "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///
import typing

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

import flyte

image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nproc_per_node=1,
        # if you want to do local testing set nnodes=1
        nnodes=2,
    ),
    image=image,
)


class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Prepare a DataLoader with a DistributedSampler so each rank
    gets a shard of the dataset.
    """
    # Dummy dataset
    x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(x_train, y_train)
    # Distributed-aware sampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


def train_loop(epochs: int = 3) -> float:
    """
    A simple training loop for linear regression.
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    final_loss = 0.0
    for _ in range(epochs):
        for x, y in dataloader:
            outputs = model(x)
            loss = criterion(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            final_loss = loss.item()
    if torch.distributed.get_rank() == 0:
        print(f"Loss: {final_loss}")
    return final_loss


@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    A nested task that sets up a simple distributed training job using PyTorch's
    elastic launch.
    """
    print("starting launcher")
    loss = train_loop(epochs=epochs)
    print("Training complete")
    return loss


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(torch_distributed_train, epochs=3)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pytorch/pytorch_example.py*
## API reference
See the [PyTorch API reference](../../api-reference/integrations/pytorch/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/ray ===
# Ray
The Ray plugin lets you run [Ray](https://www.ray.io/) jobs natively on Kubernetes. Flyte provisions a transient Ray cluster for each task execution using [KubeRay](https://github.com/ray-project/kuberay) and tears it down on completion.
## When to use this plugin
- Distributed Python workloads (parallel computation, data processing)
- ML training with Ray Train or hyperparameter tuning with Ray Tune
- Ray Serve inference workloads
- Any workload that benefits from Ray's actor model or task parallelism
## Installation
```bash
pip install flyteplugins-ray
```
Your task image must also include a compatible version of Ray:
```python
image = (
    flyte.Image.from_debian_base(name="ray")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)
```
## Configuration
Create a `RayJobConfig` and pass it as `plugin_config` to a `TaskEnvironment`:
```python
from flyteplugins.ray import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)

ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
)
```
### `RayJobConfig` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `worker_node_config` | `List[WorkerNodeConfig]` | **Required.** List of worker group configurations |
| `head_node_config` | `HeadNodeConfig` | Head node configuration (optional) |
| `enable_autoscaling` | `bool` | Enable Ray autoscaler (default: `False`) |
| `runtime_env` | `dict` | Ray runtime environment (pip packages, env vars, etc.) |
| `address` | `str` | Connect to an existing Ray cluster instead of provisioning one |
| `shutdown_after_job_finishes` | `bool` | Shut down the cluster after the job completes (default: `False`) |
| `ttl_seconds_after_finished` | `int` | Seconds to keep the cluster after completion before cleanup |
### `WorkerNodeConfig` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `group_name` | `str` | **Required.** Name of this worker group |
| `replicas` | `int` | **Required.** Number of worker replicas |
| `min_replicas` | `int` | Minimum replicas (for autoscaling) |
| `max_replicas` | `int` | Maximum replicas (for autoscaling) |
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for workers |
| `requests` | `Resources` | Resource requests per worker |
| `limits` | `Resources` | Resource limits per worker |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with `requests`/`limits`) |
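For example, `min_replicas` and `max_replicas` only take effect when autoscaling is enabled on the job config. A sketch with illustrative values:

```python
from flyteplugins.ray import RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    enable_autoscaling=True,  # let the Ray autoscaler resize the group
    worker_node_config=[
        WorkerNodeConfig(
            group_name="autoscale-group",
            replicas=2,       # initial replica count
            min_replicas=1,
            max_replicas=8,
        )
    ],
)
```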
### `HeadNodeConfig` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for the head node |
| `requests` | `Resources` | Resource requests for the head node |
| `limits` | `Resources` | Resource limits for the head node |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with `requests`/`limits`) |
### Connecting to an existing cluster
To connect to an existing Ray cluster instead of provisioning a new one, set the `address` parameter:
```python
ray_config = RayJobConfig(
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    address="ray://existing-cluster:10001",
)
```
## Examples
The following example shows how to configure Ray in a `TaskEnvironment`. Flyte automatically provisions a Ray cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-ray",
#    "ray[default]==2.46.0"
# ]
# main = "hello_ray_nested"
# params = "3"
# ///
import asyncio
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte.remote
import flyte.storage


@ray.remote
def f(x):
    return x * x


ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray", "pip", "mypy")
)

task_env = flyte.TaskEnvironment(
    name="hello_ray", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(3, 4), memory=("3000Mi", "5000Mi")),
    depends_on=[task_env],
)


@task_env.task()
async def hello_ray():
    await asyncio.sleep(20)
    print("Hello from the Ray task!")


@ray_env.task
async def hello_ray_nested(n: int = 3) -> typing.List[int]:
    print("running ray task")
    t = asyncio.create_task(hello_ray())
    futures = [f.remote(i) for i in range(n)]
    res = ray.get(futures)
    await t
    return res


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_ray_nested)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_example.py*
The next example demonstrates how Flyte can create ephemeral Ray clusters and run a subtask that connects to an existing Ray cluster:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-ray",
#    "ray[default]==2.46.0"
# ]
# main = "create_ray_cluster"
# params = ""
# ///
import os
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte.storage


@ray.remote
def f(x):
    return x * x


ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=3600,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)

task_env = flyte.TaskEnvironment(
    name="ray_client", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

ray_env = flyte.TaskEnvironment(
    name="ray_cluster",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(2, 4), memory=("2000Mi", "4000Mi")),
    depends_on=[task_env],
)


@task_env.task()
async def hello_ray(cluster_ip: str) -> typing.List[int]:
    """
    Run a simple Ray task that connects to an existing Ray cluster.
    """
    ray.init(address=f"ray://{cluster_ip}:10001")
    futures = [f.remote(i) for i in range(5)]
    res = ray.get(futures)
    return res


@ray_env.task
async def create_ray_cluster() -> str:
    """
    Create a Ray cluster and return the head node IP address.
    """
    print("creating ray cluster")
    cluster_ip = os.getenv("MY_POD_IP")
    if cluster_ip is None:
        raise ValueError("MY_POD_IP environment variable is not set")
    return f"{cluster_ip}"


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(create_ray_cluster)
    run.wait()
    print("run url:", run.url)
    print("cluster created, running ray task")
    print("ray address:", run.outputs()[0])
    run = flyte.run(hello_ray, cluster_ip=run.outputs()[0])
    print("run url:", run.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_existing_example.py*
## API reference
See the [Ray API reference](../../api-reference/integrations/ray/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/snowflake ===
# Snowflake
The Snowflake connector lets you run SQL queries against [Snowflake](https://www.snowflake.com/) directly from Flyte tasks. Queries are submitted asynchronously and polled for completion, so they don't block a worker while waiting for results.
The connector supports:
- Parameterized SQL queries with typed inputs
- Key-pair and password-based authentication
- Query results returned as DataFrames
- Automatic links to the Snowflake query dashboard in the Flyte UI
- Query cancellation on task abort
## Installation
```bash
pip install flyteplugins-snowflake
```
This installs the Snowflake Python connector and the `cryptography` library for key-pair authentication.
## Quick start
Here's a minimal example that runs a SQL query on Snowflake:
```python {hl_lines=[2, 4, 12]}
from flyte.io import DataFrame
from flyteplugins.connectors.snowflake import Snowflake, SnowflakeConfig

config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
)

count_users = Snowflake(
    name="count_users",
    query_template="SELECT COUNT(*) FROM users",
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```
This defines a task called `count_users` that runs `SELECT COUNT(*) FROM users` on the configured Snowflake instance. When executed, the connector:
1. Connects to Snowflake using the provided configuration
2. Submits the query asynchronously
3. Polls until the query completes or fails
4. Provides a link to the query in the Snowflake dashboard
To run the task, create a `TaskEnvironment` from it and execute it locally or remotely:
```python {hl_lines=3}
import flyte

snowflake_env = flyte.TaskEnvironment.from_task("snowflake_env", count_users)

if __name__ == "__main__":
    flyte.init_from_config()

    # Run locally (connector runs in-process, requires credentials and packages locally)
    run = flyte.with_runcontext(mode="local").run(count_users)

    # Run remotely (connector runs on the control plane)
    run = flyte.with_runcontext(mode="remote").run(count_users)

    print(run.url)
```
> [!NOTE]
> The `TaskEnvironment` created by `from_task` does not need an image or pip packages. Snowflake tasks are connector tasks, which means the query executes on the connector service, not in your task container. In `local` mode, the connector runs in-process and requires `flyteplugins-snowflake` and credentials to be available on your machine. In `remote` mode, the connector runs on the control plane.
## Configuration
The `SnowflakeConfig` dataclass defines the connection settings for your Snowflake instance.
### Required fields
| Field | Type | Description |
| ----------- | ----- | ------------------------------------------------------- |
| `account` | `str` | Snowflake account identifier (e.g. `"myorg-myaccount"`) |
| `database` | `str` | Target database name |
| `schema` | `str` | Target schema name (e.g. `"PUBLIC"`) |
| `warehouse` | `str` | Compute warehouse to use for query execution |
| `user` | `str` | Snowflake username |
### Additional connection parameters
Use `connection_kwargs` to pass any additional parameters supported by the [Snowflake Python connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api). This is a dictionary that gets forwarded directly to `snowflake.connector.connect()`.
Common options include:
| Parameter | Type | Description |
| --------------- | ----- | -------------------------------------------------------------------------- |
| `role` | `str` | Snowflake role to use for the session |
| `authenticator` | `str` | Authentication method (e.g. `"snowflake"`, `"externalbrowser"`, `"oauth"`) |
| `token` | `str` | OAuth token when using `authenticator="oauth"` |
| `login_timeout` | `int` | Timeout in seconds for the login request |
Example with a role:
```python {hl_lines=8}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "role": "DATA_ANALYST",
    },
)
```
## Authentication
The connector supports two authentication approaches: key-pair authentication, and password-based (or other) authentication supplied through `connection_kwargs`.
### Key-pair authentication
Key-pair authentication is the recommended approach for automated workloads. Pass the names of the Flyte secrets containing the private key and optional passphrase:
```python {hl_lines=[5, 6]}
query = Snowflake(
    name="secure_query",
    query_template="SELECT * FROM sensitive_data",
    plugin_config=config,
    snowflake_private_key="my-snowflake-private-key",
    snowflake_private_key_passphrase="my-snowflake-pk-passphrase",
)
```
The `snowflake_private_key` parameter is the name of the secret (or secret key) that contains your PEM-encoded private key. The `snowflake_private_key_passphrase` parameter is the name of the secret (or secret key) that contains the passphrase, if your key is encrypted. If your key is not encrypted, omit the passphrase parameter.
The connector decodes the PEM key and converts it to DER format for Snowflake authentication.
> [!NOTE]
> If your credentials are stored in a secret group, you can pass `secret_group` to the `Snowflake` task. The plugin expects `snowflake_private_key` and
> `snowflake_private_key_passphrase` to be keys within the same secret group.
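For example, if a group named `snowflake-creds` (a hypothetical name) holds both keys, the task might be declared as follows:

```python
query = Snowflake(
    name="secure_query",
    query_template="SELECT * FROM sensitive_data",
    plugin_config=config,
    secret_group="snowflake-creds",  # hypothetical secret group name
    snowflake_private_key="snowflake_private_key",
    snowflake_private_key_passphrase="snowflake_private_key_passphrase",
)
```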
### Password authentication
Send the password via `connection_kwargs`:
```python {hl_lines=8}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "password": "my-password",
    },
)
```
### OAuth authentication
For OAuth-based authentication, specify the authenticator and token:
```python {hl_lines=["8-9"]}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "authenticator": "oauth",
        "token": "",
    },
)
```
## Query templating
Use the `inputs` parameter to define typed inputs for your query. Input values are bound using the `%(param)s` syntax supported by the [Snowflake Python connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api), which handles type conversion and escaping automatically.
### Supported input types
The `inputs` dictionary maps parameter names to Python values. Supported scalar types include `str`, `int`, `float`, and `bool`.
To insert multiple rows in a single query, you can also provide lists as input values. When using list inputs, set `batch=True` on the `Snowflake` task. This enables automatic batching: the inputs are expanded and sent as a single multi-row query instead of multiple individual statements.
### Batched `INSERT` with list inputs
When `batch=True` is enabled, a parameterized `INSERT` query with list inputs is automatically expanded into a multi-row `VALUES` statement.
Example:
```python
query = "INSERT INTO t (a, b) VALUES (%(a)s, %(b)s)"
inputs = {"a": [1, 2], "b": ["x", "y"]}
```
This is expanded into:
```sql
INSERT INTO t (a, b)
VALUES (%(a_0)s, %(b_0)s), (%(a_1)s, %(b_1)s)
```
with the following flattened parameters:
```python
flat_params = {
    "a_0": 1,
    "b_0": "x",
    "a_1": 2,
    "b_1": "y",
}
```
#### Constraints
- The query must contain exactly one `VALUES (...)` clause.
- All list inputs must have the same non-zero length.
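As a rough illustration (not the plugin's actual code), the expansion described above can be sketched in plain Python:

```python
import re


def expand_batch(query: str, inputs: dict) -> tuple[str, dict]:
    """Expand the single VALUES (...) clause into one row per list index."""
    match = re.search(r"VALUES\s*\((.*)\)", query, flags=re.IGNORECASE)
    if match is None:
        raise ValueError("query must contain a VALUES (...) clause")
    template = match.group(1)  # e.g. "%(a)s, %(b)s"
    names = re.findall(r"%\((\w+)\)s", template)
    lengths = {len(inputs[name]) for name in names}
    if len(lengths) != 1 or 0 in lengths:
        raise ValueError("all list inputs must have the same non-zero length")
    rows, flat_params = [], {}
    for i in range(lengths.pop()):
        row = template
        for name in names:
            # Rename %(a)s -> %(a_0)s, %(a_1)s, ... and record the flattened value
            row = row.replace(f"%({name})s", f"%({name}_{i})s")
            flat_params[f"{name}_{i}"] = inputs[name][i]
        rows.append(f"({row})")
    expanded = query[: match.start()] + "VALUES " + ", ".join(rows) + query[match.end():]
    return expanded, flat_params
```

Running it on the example above yields the multi-row query and flattened parameters shown earlier.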
### Parameterized `SELECT`
```python {hl_lines=[5, 7]}
from flyte.io import DataFrame

events_by_date = Snowflake(
    name="events_by_date",
    query_template="SELECT * FROM events WHERE event_date = %(event_date)s",
    plugin_config=config,
    inputs={"event_date": str},
    output_dataframe_type=DataFrame,
)
```
You can call the task with the required inputs:
```python {hl_lines=3}
@env.task
async def fetch_events() -> DataFrame:
    return await events_by_date(event_date="2024-01-15")
```
### Multiple inputs
You can define multiple input parameters of different types:
```python {hl_lines=["4-8", "12-15"]}
filtered_events = Snowflake(
    name="filtered_events",
    query_template="""
        SELECT * FROM events
        WHERE event_date >= %(start_date)s
        AND event_date <= %(end_date)s
        AND region = %(region)s
        AND score > %(min_score)s
    """,
    plugin_config=config,
    inputs={
        "start_date": str,
        "end_date": str,
        "region": str,
        "min_score": float,
    },
    output_dataframe_type=DataFrame,
)
```
> [!NOTE]
> The query template is normalized before execution: newlines and tabs are replaced with spaces, and consecutive whitespace is collapsed. You can format your queries across multiple lines for readability without affecting execution.
## Retrieving query results
If your query produces output, set `output_dataframe_type` to capture the results. `output_dataframe_type` accepts `DataFrame` from `flyte.io`, a meta-wrapper type that represents tabular results; it can be materialized into a concrete DataFrame implementation by calling `open()` with the target type, followed by `all()`.
```python {hl_lines=13}
from flyte.io import DataFrame

top_customers = Snowflake(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```
At present, only `pandas.DataFrame` is supported. The results are returned directly when you call the task:
```python {hl_lines=6}
import pandas as pd

@env.task
async def analyze_top_customers() -> dict:
    df = await top_customers()
    pandas_df = await df.open(pd.DataFrame).all()
    total_spend = pandas_df["total_spend"].sum()
    return {"total_spend": float(total_spend)}
```
If you specify `pandas.DataFrame` as the `output_dataframe_type`, you do not need to call `open()` and `all()` to materialize the results.
```python {hl_lines=[1, 13, "18-19"]}
import pandas as pd

top_customers = Snowflake(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=pd.DataFrame,
)

@env.task
async def analyze_top_customers() -> dict:
    df = await top_customers()
    total_spend = df["total_spend"].sum()
    return {"total_spend": float(total_spend)}
```
> [!NOTE]
> Be sure to inject the `SNOWFLAKE_PRIVATE_KEY` and `SNOWFLAKE_PRIVATE_KEY_PASSPHRASE` environment variables as secrets into your downstream tasks, as they must have access to Snowflake credentials in order to retrieve the DataFrame results. For more on creating secrets, see **Configure tasks > Secrets**.
If you don't need query results (for example, `DDL` statements or `INSERT` queries), omit `output_dataframe_type`.
## End-to-end example
Here's a complete workflow that uses the Snowflake connector as part of a data pipeline. The workflow creates a staging table, inserts records, queries aggregated results and processes them in a downstream task.
```python
import flyte
from flyte.io import DataFrame
from flyteplugins.connectors.snowflake import Snowflake, SnowflakeConfig

config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "role": "ETL_ROLE",
    },
)

# Step 1: Create the staging table if it doesn't exist
create_staging = Snowflake(
    name="create_staging",
    query_template="""
        CREATE TABLE IF NOT EXISTS staging.daily_events (
            event_id STRING,
            event_date DATE,
            user_id STRING,
            event_type STRING,
            payload VARIANT
        )
    """,
    plugin_config=config,
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
)

# Step 2: Insert records into the staging table
insert_events = Snowflake(
    name="insert_event",
    query_template="""
        INSERT INTO staging.daily_events (event_id, event_date, user_id, event_type)
        VALUES (%(event_id)s, %(event_date)s, %(user_id)s, %(event_type)s)
    """,
    plugin_config=config,
    inputs={
        "event_id": list[str],
        "event_date": list[str],
        "user_id": list[str],
        "event_type": list[str],
    },
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
    batch=True,
)

# Step 3: Query aggregated results for a given date
daily_summary = Snowflake(
    name="daily_summary",
    query_template="""
        SELECT
            event_type,
            COUNT(*) AS event_count,
            COUNT(DISTINCT user_id) AS unique_users
        FROM staging.daily_events
        WHERE event_date = %(report_date)s
        GROUP BY event_type
        ORDER BY event_count DESC
    """,
    plugin_config=config,
    inputs={"report_date": str},
    output_dataframe_type=DataFrame,
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
)

# Create environments for all Snowflake tasks
snowflake_env = flyte.TaskEnvironment.from_task(
    "snowflake_env", create_staging, insert_events, daily_summary
)

# Main pipeline environment that depends on the Snowflake task environments
env = flyte.TaskEnvironment(
    name="analytics_env",
    resources=flyte.Resources(memory="512Mi"),
    image=flyte.Image.from_debian_base(name="analytics").with_pip_packages(
        "flyteplugins-snowflake", pre=True
    ),
    secrets=[
        flyte.Secret(key="snowflake", as_env_var="SNOWFLAKE_PRIVATE_KEY"),
        flyte.Secret(
            key="snowflake_passphrase", as_env_var="SNOWFLAKE_PRIVATE_KEY_PASSPHRASE"
        ),
    ],
    depends_on=[snowflake_env],
)

# Step 4: Process the results in Python
@env.task
async def generate_report(summary: DataFrame) -> dict:
    import pandas as pd

    df = await summary.open(pd.DataFrame).all()
    total_events = df["event_count"].sum()
    top_event = df.iloc[0]["event_type"]
    return {
        "total_events": int(total_events),
        "top_event_type": top_event,
        "event_types_count": len(df),
    }

# Compose the pipeline
@env.task
async def run_daily_pipeline(
    event_ids: list[str],
    event_dates: list[str],
    user_ids: list[str],
    event_types: list[str],
) -> dict:
    await create_staging()
    await insert_events(
        event_id=event_ids,
        event_date=event_dates,
        user_id=user_ids,
        event_type=event_types,
    )
    summary = await daily_summary(report_date=event_dates[0])
    return await generate_report(summary=summary)

if __name__ == "__main__":
    flyte.init_from_config()

    # Run locally
    run = flyte.with_runcontext(mode="local").run(
        run_daily_pipeline,
        event_ids=["event-1", "event-2"],
        event_dates=["2023-01-01", "2023-01-02"],
        user_ids=["user-1", "user-2"],
        event_types=["click", "view"],
    )

    # Or run remotely
    run = flyte.with_runcontext(mode="remote").run(
        run_daily_pipeline,
        event_ids=["event-1", "event-2"],
        event_dates=["2023-01-01", "2023-01-02"],
        user_ids=["user-1", "user-2"],
        event_types=["click", "view"],
    )
    print(run.url)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/connectors/snowflake/example.py*
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/spark ===
# Spark
The Spark plugin lets you run [Apache Spark](https://spark.apache.org/) jobs natively on Kubernetes. Flyte manages the full cluster lifecycle: provisioning a transient Spark cluster for each task execution, running the job, and tearing the cluster down on completion.
Under the hood, the plugin uses the [Spark on Kubernetes Operator](https://github.com/GoogleCloudPlatform/spark-on-k8s-operator) to create and manage Spark applications. No external Spark service or long-running cluster is required.
## When to use this plugin
- Large-scale data processing and ETL pipelines
- Jobs that benefit from Spark's distributed execution engine (Spark SQL, PySpark, Spark MLlib)
- Workloads that need Hadoop-compatible storage access (S3, GCS, HDFS)
## Installation
```bash
pip install flyteplugins-spark
```
## Configuration
Create a `Spark` configuration and pass it as `plugin_config` to a `TaskEnvironment`:
```python
from flyteplugins.spark import Spark
spark_config = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)
spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_config,
    image=image,
)
```
### `Spark` parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| `spark_conf` | `Dict[str, str]` | Spark configuration key-value pairs (e.g., executor memory, cores, instances) |
| `hadoop_conf` | `Dict[str, str]` | Hadoop configuration key-value pairs (e.g., S3/GCS access settings) |
| `executor_path` | `str` | Path to the Python binary for PySpark executors |
| `applications_path` | `str` | Path to the main Spark application file |
| `driver_pod` | `PodTemplate` | Pod template for the Spark driver pod |
| `executor_pod` | `PodTemplate` | Pod template for the Spark executor pods |
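For example, `hadoop_conf` can be combined with `spark_conf` to configure storage access. The sketch below is illustrative only: the specific Hadoop keys and the credentials provider class depend on your storage backend and cluster setup.

```python
from flyteplugins.spark import Spark

spark_config = Spark(
    spark_conf={
        "spark.executor.instances": "2",
        "spark.executor.memory": "1000M",
    },
    hadoop_conf={
        # Illustrative S3A setting; adjust for your cluster's credential setup.
        "fs.s3a.aws.credentials.provider": "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    },
)
```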
### Accessing the Spark session
Inside a Spark task, the `SparkSession` is available through the task context:
```python
from flyte._context import internal_ctx
@spark_env.task
async def my_spark_task() -> float:
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    # Use spark as a normal SparkSession
    df = spark.read.parquet("s3://my-bucket/data.parquet")
    return df.count()
```
### Overriding configuration at runtime
You can override Spark configuration for individual task calls using `.override()`:
```python
from copy import deepcopy
updated_config = deepcopy(spark_config)
updated_config.spark_conf["spark.executor.instances"] = "4"
result = await my_spark_task.override(plugin_config=updated_config)()
```
## Example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-spark"
# ]
# main = "hello_spark_nested"
# params = "3"
# ///
import random
from copy import deepcopy
from operator import add
from flyteplugins.spark.task import Spark
import flyte.remote
from flyte._context import internal_ctx
image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="spark", python_version=(3, 10), registry="ghcr.io/flyteorg")
    .with_pip_packages("flyteplugins-spark", pre=True)
)
task_env = flyte.TaskEnvironment(
    name="get_pi", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar",
    },
)
spark_env = flyte.TaskEnvironment(
    name="spark_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=spark_conf,
    image=image,
    depends_on=[task_env],
)
def f(_):
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0
@task_env.task
async def get_pi(count: int, partitions: int) -> float:
    return 4.0 * count / partitions
@spark_env.task
async def hello_spark_nested(partitions: int = 3) -> float:
    n = 1 * partitions
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
    return await get_pi(count, partitions)
@task_env.task
async def spark_overrider(executor_instances: int = 3, partitions: int = 4) -> float:
    updated_spark_conf = deepcopy(spark_conf)
    updated_spark_conf.spark_conf["spark.executor.instances"] = str(executor_instances)
    return await hello_spark_nested.override(plugin_config=updated_spark_conf)(partitions=partitions)
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_spark_nested)
    print(r.name)
    print(r.url)
    r.wait()
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/spark/spark_example.py*
## API reference
See the [Spark API reference](../../api-reference/integrations/spark/_index) for full details.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/wandb ===
# Weights & Biases
[Weights & Biases](https://wandb.ai) (W&B) is a platform for tracking machine learning experiments, visualizing metrics and optimizing hyperparameters. This plugin integrates W&B with Flyte, enabling you to:
- Automatically initialize W&B runs in your tasks without boilerplate
- Link directly from the Flyte UI to your W&B runs and sweeps
- Share W&B runs across parent and child tasks
- Track distributed training jobs across multiple GPUs and nodes
- Run hyperparameter sweeps with parallel agents
## Installation
```bash
pip install flyteplugins-wandb
```
You also need a W&B API key. Store it as a Flyte secret so your tasks can authenticate with W&B.
## Quick start
Here's a minimal example that logs metrics to W&B from a Flyte task:
```python
import flyte
from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init
env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)
@wandb_init
@env.task
async def train_model() -> str:
    wandb_run = get_wandb_run()
    # Your training code here
    for epoch in range(10):
        loss = 1.0 / (epoch + 1)
        wandb_run.log({"epoch": epoch, "loss": loss})
    return "Training complete"
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext(
        custom_context=wandb_config(
            project="my-project",
            entity="my-team",
        ),
    ).run(train_model)
    print(f"run url: {r.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/quick_start.py*
This example demonstrates the core pattern:
1. **Define a task environment** with the plugin installed and your W&B API key as a secret
2. **Decorate your task** with `@wandb_init` (must be the outermost decorator, above `@env.task`)
3. **Access the run** with `get_wandb_run()` to log metrics
4. **Provide configuration** via `wandb_config()` when running the task
The plugin handles calling `wandb.init()` and `wandb.finish()` for you, and automatically adds a link to the W&B run in the Flyte UI.
## What's next
This integration guide is split into focused sections, depending on how you want to use Weights & Biases with Flyte:
- **Weights & Biases > Experiments**: Create and manage W&B runs from Flyte tasks.
- **Weights & Biases > Distributed training**: Track experiments across multi-GPU and multi-node training jobs.
- **Weights & Biases > Sweeps**: Run hyperparameter searches and manage sweep execution from Flyte tasks.
- **Weights & Biases > Downloading logs**: Download logs and execution metadata from Weights & Biases.
- **Weights & Biases > Constraints and best practices**: Learn about limitations, edge cases and recommended patterns.
- **Weights & Biases > Manual integration**: Use Weights & Biases directly in Flyte tasks without decorators or helpers.
> **Note**
>
> We've included additional examples developed while testing edge cases of the plugin [here](https://github.com/flyteorg/flyte-sdk/tree/main/plugins/wandb/examples).
## Subpages
- **Weights & Biases > Experiments**
- **Weights & Biases > Distributed training**
- **Weights & Biases > Sweeps**
- **Weights & Biases > Downloading logs**
- **Weights & Biases > Constraints and best practices**
- **Weights & Biases > Manual integration**
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/wandb/experiments ===
# Experiments
The `@wandb_init` decorator automatically initializes a W&B run when your task executes and finishes it when the task completes. This section covers the different ways to use it.
## Basic usage
Apply `@wandb_init` as the outermost decorator on your task:
```python {hl_lines=1}
@wandb_init
@env.task
async def my_task() -> str:
    run = get_wandb_run()
    run.log({"metric": 42})
    return "done"
```
The decorator:
- Calls `wandb.init()` before your task code runs
- Calls `wandb.finish()` after your task completes (or fails)
- Adds a link to the W&B run in the Flyte UI
You can also use it on synchronous tasks:
```python {hl_lines=[1, 3]}
@wandb_init
@env.task
def my_sync_task() -> str:
    run = get_wandb_run()
    run.log({"metric": 42})
    return "done"
```
## Accessing the run object
Use `get_wandb_run()` to access the current W&B run object:
```python {hl_lines=6}
from flyteplugins.wandb import get_wandb_run
@wandb_init
@env.task
async def train() -> str:
    run = get_wandb_run()
    # Log metrics
    run.log({"loss": 0.5, "accuracy": 0.9})
    # Access run properties
    print(f"Run ID: {run.id}")
    print(f"Run URL: {run.url}")
    print(f"Project: {run.project}")
    # Log configuration
    run.config.update({"learning_rate": 0.001, "batch_size": 32})
    return run.id
```
## Parent-child task relationships
When a parent task calls child tasks, the plugin can share the same W&B run across all of them. This is useful for tracking an entire workflow in a single run.
```python {hl_lines=[1, 9, 16]}
@wandb_init
@env.task
async def child_task(x: int) -> int:
    run = get_wandb_run()
    run.log({"child_metric": x * 2})
    return x * 2
@wandb_init
@env.task
async def parent_task() -> int:
    run = get_wandb_run()
    run.log({"parent_metric": 100})
    # Child task shares the parent's run by default
    result = await child_task(5)
    return result
```
By default (`run_mode="auto"`), child tasks reuse their parent's W&B run. All metrics logged by the parent and children appear in the same run in the W&B UI.
## Run modes
The `run_mode` parameter controls how tasks create or reuse W&B runs. There are three modes:
| Mode | Behavior |
| ---------------- | -------------------------------------------------------------------------- |
| `auto` (default) | Create a new run if no parent run exists, otherwise reuse the parent's run |
| `new` | Always create a new run, even if a parent run exists |
| `shared` | Always reuse the parent's run (fails if no parent run exists) |
### Using `run_mode="new"` for independent runs
```python {hl_lines=1}
@wandb_init(run_mode="new")
@env.task
async def independent_child(x: int) -> int:
    run = get_wandb_run()
    # This task gets its own separate run
    run.log({"independent_metric": x})
    return x
@wandb_init
@env.task
async def parent_task() -> str:
    run = get_wandb_run()
    parent_run_id = run.id
    # This child creates its own run
    await independent_child(5)
    # Parent's run is unchanged
    assert run.id == parent_run_id
    return parent_run_id
```
### Using `run_mode="shared"` for explicit sharing
```python {hl_lines=1}
@wandb_init(run_mode="shared")
@env.task
async def must_share_run(x: int) -> int:
    # This task requires a parent run to exist
    # It will fail if called as a top-level task
    run = get_wandb_run()
    run.log({"shared_metric": x})
    return x
```
## Configuration with `wandb_config`
Use `wandb_config()` to configure W&B runs. You can set it at the workflow level or override it for specific tasks, allowing you to provide configuration values at runtime.
### Workflow-level configuration
```python {hl_lines=["5-9"]}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        custom_context=wandb_config(
            project="my-project",
            entity="my-team",
            tags=["experiment-1", "production"],
            config={"model": "resnet50", "dataset": "imagenet"},
        ),
    ).run(train_task)
```
### Overriding configuration for child tasks
Use `wandb_config()` as a context manager to override settings for specific child task calls:
```python {hl_lines=[8, 12]}
@wandb_init
@env.task
async def parent_task() -> str:
    run = get_wandb_run()
    run.log({"parent_metric": 100})
    # Override tags and config for this child call
    with wandb_config(tags=["special-run"], config={"learning_rate": 0.01}):
        await child_task(10)
    # Override run_mode for this child call
    with wandb_config(run_mode="new"):
        await child_task(20)  # Gets its own run
    return "done"
```
## Using traces with W&B runs
Flyte traces can access the parent task's W&B run without needing the `@wandb_init` decorator. This is useful for helper functions that should log to the same run:
```python {hl_lines=[1, 3]}
@flyte.trace
async def log_validation_metrics(accuracy: float, f1: float):
    run = get_wandb_run()
    if run:
        run.log({"val_accuracy": accuracy, "val_f1": f1})
@wandb_init
@env.task
async def train_and_validate() -> str:
    run = get_wandb_run()
    # Training loop
    for epoch in range(10):
        run.log({"train_loss": 1.0 / (epoch + 1)})
    # Trace logs to the same run
    await log_validation_metrics(accuracy=0.95, f1=0.92)
    return "done"
```
> **Note**
>
> Do not apply `@wandb_init` to traces. Traces automatically access the parent task's run via `get_wandb_run()`.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/wandb/distributed_training ===
# Distributed training
When running distributed training jobs, multiple processes run simultaneously across GPUs. The `@wandb_init` decorator automatically detects distributed training environments and coordinates W&B logging across processes.
The plugin:
- Auto-detects distributed context from environment variables (set by launchers like `torchrun`)
- Controls which processes initialize W&B runs based on the `run_mode` and `rank_scope` parameters
- Generates unique run IDs that distinguish between workers and ranks
- Adds links to W&B runs in the Flyte UI
## Quick start
Here's a minimal single-node example that logs metrics from a distributed training task. By default (`run_mode="auto"`, `rank_scope="global"`), only rank 0 logs to W&B:
```python
import flyte
import torch
import torch.distributed
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init
image = flyte.Image.from_debian_base(name="torch-wandb").with_pip_packages(
    "flyteplugins-wandb", "flyteplugins-pytorch"
)
env = flyte.TaskEnvironment(
    name="distributed_env",
    image=image,
    resources=flyte.Resources(gpu="A100:2"),
    plugin_config=Elastic(nproc_per_node=2, nnodes=1),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)
@wandb_init
@env.task
def train() -> float:
    torch.distributed.init_process_group("nccl")
    # Only rank 0 gets a W&B run object; others get None
    run = get_wandb_run()
    # Simulate training
    for step in range(100):
        loss = 1.0 / (step + 1)
        # Safe to call on all ranks - only rank 0 actually logs
        if run:
            run.log({"loss": loss, "step": step})
    torch.distributed.destroy_process_group()
    return loss
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        custom_context=wandb_config(project="my-project", entity="my-team")
    ).run(train)
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/distributed_training_quick_start.py*
A few things to note:
1. Use the `Elastic` plugin to configure distributed training (number of processes, nodes)
2. Apply `@wandb_init` as the outermost decorator
3. Check that `run` is not `None` before logging; only the primary rank has a run object in `auto` mode
> **Note**
>
> The `if run:` check is always safe regardless of run mode. In `shared` and `new` modes all ranks get a run object, but the check doesn't hurt and keeps your code portable across modes.
## Run modes in distributed training
The `run_mode` parameter controls how W&B runs are created across distributed processes. The behavior differs between single-node (one machine, multiple GPUs) and multi-node (multiple machines) setups.
### Single-node behavior
| Mode | Which ranks log | Result |
| ---------------- | --------------------- | -------------------------------------- |
| `auto` (default) | Only rank 0 | 1 W&B run |
| `shared` | All ranks to same run | 1 W&B run with metrics labeled by rank |
| `new` | Each rank separately | N W&B runs (grouped in UI) |
### Multi-node behavior
For multi-node training, the `rank_scope` parameter controls the granularity of W&B runs:
- **`global`** (default): Treat all workers as one unit
- **`worker`**: Treat each worker/node independently
The combination of `run_mode` and `rank_scope` determines logging behavior:
| `run_mode` | `rank_scope` | Who initializes W&B | W&B Runs | Grouping |
| ---------- | ------------ | ---------------------- | -------- | -------- |
| `auto` | `global` | Global rank 0 only | 1 | - |
| `auto` | `worker` | Local rank 0 per worker | N | - |
| `shared` | `global` | All ranks (shared globally) | 1 | - |
| `shared` | `worker` | All ranks (shared per worker) | N | - |
| `new` | `global` | All ranks | N × M | 1 group |
| `new` | `worker` | All ranks | N × M | N groups |
Where `N` = number of workers/nodes, `M` = processes per worker.
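As a sanity check on the table, the resulting run counts can be computed with a small pure-Python helper. This is an illustrative sketch only, not part of the plugin API:

```python
def num_wandb_runs(run_mode: str, rank_scope: str, workers: int, procs_per_worker: int) -> int:
    """Number of W&B runs produced, per the run_mode/rank_scope table.

    workers = N (nodes), procs_per_worker = M (processes per node).
    """
    if run_mode in ("auto", "shared"):
        # One run total with global scope; one run per worker otherwise.
        return 1 if rank_scope == "global" else workers
    if run_mode == "new":
        # Every rank gets its own run, regardless of scope.
        return workers * procs_per_worker
    raise ValueError(f"unknown run_mode: {run_mode}")

# Two nodes with four processes each:
print(num_wandb_runs("auto", "global", 2, 4))    # 1
print(num_wandb_runs("shared", "worker", 2, 4))  # 2
print(num_wandb_runs("new", "worker", 2, 4))     # 8
```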
### Choosing run mode and rank scope
- **`auto`** (recommended): Use when you want clean dashboards with minimal runs. Most metrics (loss, accuracy) are the same across ranks after gradient synchronization, so logging from one rank is sufficient.
- **`shared`**: Use when you need to compare metrics across ranks in a single view. Each rank's metrics are labeled with an `x_label` identifier. Useful for debugging load imbalance or per-GPU throughput.
- **`new`**: Use when you need completely separate runs per GPU, for example to track GPU-specific metrics or compare training dynamics across devices.
For multi-node training:
- Use **`rank_scope="global"`** (default) for most cases. A single consolidated run across all nodes is sufficient since metrics like loss and accuracy converge after gradient synchronization.
- Use **`rank_scope="worker"`** for debugging and per-node analysis. This is useful when you need to inspect data distribution across nodes, compare predictions from different workers, or track metrics on individual batches outside the main node.
## Single-node multi-GPU
For single-node distributed training, configure the `Elastic` plugin with `nnodes=1` and set `nproc_per_node` to your GPU count.
### Basic example with `auto` mode
```python {hl_lines=["6-7", 13, 18, 30]}
import os
import torch
import torch.distributed
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, get_wandb_run
env = flyte.TaskEnvironment(
    name="single_node_env",
    image=image,
    resources=flyte.Resources(gpu="A100:4"),
    plugin_config=Elastic(nproc_per_node=4, nnodes=1),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)
@wandb_init  # run_mode="auto" (default)
@env.task
def train_single_node() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    run = get_wandb_run()
    # Training loop - only rank 0 logs
    for epoch in range(10):
        loss = train_epoch(model, dataloader, device)
        if run:
            run.log({"epoch": epoch, "loss": loss})
    torch.distributed.destroy_process_group()
    return loss
```
### Using `shared` mode for per-rank metrics
When you need to see metrics from all GPUs in a single run, use `run_mode="shared"`:
```python {hl_lines=[3, 13, 19]}
import os
@wandb_init(run_mode="shared")
@env.task
def train_with_per_gpu_metrics() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    # In shared mode, all ranks get a run object
    run = get_wandb_run()
    for step in range(1000):
        loss, throughput = train_step(model, batch, device)
        # Each rank logs with automatic x_label identification
        if run:
            run.log({
                "loss": loss,
                "throughput_samples_per_sec": throughput,
                "gpu_memory_used": torch.cuda.memory_allocated(device),
            })
    torch.distributed.destroy_process_group()
    return loss
```
In the W&B UI, metrics from each rank appear with distinct labels, allowing you to compare GPU utilization and throughput across devices.
### Using `new` mode for per-rank runs
When you need completely separate W&B runs for each GPU, use `run_mode="new"`. Each rank gets its own run, and runs are grouped together in the W&B UI:
```python {hl_lines=[1, "11-12"]}
@wandb_init(run_mode="new") # Each rank gets its own run
@env.task
def train_per_rank() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    # ...
    # Each rank has its own W&B run
    run = get_wandb_run()
    # Run IDs: {base}-rank-{rank}
    # All runs are grouped under {base} in W&B UI
    run.log({"train/loss": loss.item(), "rank": rank})
    # ...
```
With `run_mode="new"`:
- Each rank creates its own W&B run
- Run IDs follow the pattern `{run_name}-{action_name}-rank-{rank}`
- All runs are grouped together in the W&B UI for comparison
## Multi-node training with `Elastic`
For multi-node distributed training, set `nnodes` to your node count. The `rank_scope` parameter controls whether you get a single W&B run across all nodes (`global`) or one run per node (`worker`).
### Global scope (default): Single run across all nodes
With `run_mode="auto"` and `rank_scope="global"` (both defaults), only global rank 0 initializes W&B, resulting in a single run for the entire distributed job:
```python {hl_lines=["11-12", "27-30", "35", "59-60", "95-98"]}
import os
import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, wandb_config, get_wandb_run
image = flyte.Image.from_debian_base(name="torch-wandb").with_pip_packages(
    "flyteplugins-wandb", "flyteplugins-pytorch", pre=True
)
multi_node_env = flyte.TaskEnvironment(
    name="multi_node_env",
    image=image,
    resources=flyte.Resources(
        cpu=(1, 2),
        memory=("1Gi", "10Gi"),
        gpu="A100:4",
        shm="auto",
    ),
    plugin_config=Elastic(
        nproc_per_node=4,  # GPUs per node
        nnodes=2,  # Number of nodes
    ),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)
@wandb_init  # rank_scope="global" by default → 1 run total
@multi_node_env.task
def train_multi_node(epochs: int, batch_size: int) -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    # Model with DDP
    model = MyModel().to(device)
    model = DDP(model, device_ids=[local_rank])
    # Distributed data loading
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    # Only global rank 0 gets a W&B run
    run = get_wandb_run()
    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        model.train()
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if run and batch_idx % 100 == 0:
                run.log({
                    "train/loss": loss.item(),
                    "train/epoch": epoch,
                    "train/batch": batch_idx,
                })
        if run:
            run.log({"train/epoch_complete": epoch})
    # Barrier ensures all ranks finish before cleanup
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
    return loss.item()
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        custom_context=wandb_config(
            project="multi-node-training",
            tags=["distributed", "multi-node"],
        )
    ).run(train_multi_node, epochs=10, batch_size=32)
```
With this configuration:
- Two nodes run the task, each with 4 GPUs (8 total processes)
- Only global rank 0 creates a W&B run
- Run ID follows the pattern `{run_name}-{action_name}`
- The Flyte UI shows a single link to the W&B run
### Worker scope: One run per node
Use `rank_scope="worker"` when you want each node to have its own W&B run for per-node analysis:
```python {hl_lines=[1, 8]}
@wandb_init(rank_scope="worker") # 1 run per worker/node
@multi_node_env.task
def train_per_worker(epochs: int, batch_size: int) -> float:
    torch.distributed.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # ...
    # Local rank 0 of each worker gets a W&B run
    run = get_wandb_run()
    if run:
        # Each worker logs to its own run
        run.log({"train/loss": loss.item()})
    # ...
```
With `run_mode="auto"`, `rank_scope="worker"`:
- Each node's local rank 0 creates a W&B run
- Run IDs follow the pattern `{run_name}-{action_name}-worker-{worker_index}`
- The Flyte UI shows links to each worker's W&B run
### Shared mode: All ranks log to the same run
Use `run_mode="shared"` when you need metrics from all ranks in a single view. Each rank's metrics are labeled with an `x_label` identifier.
#### Shared + global scope (1 run total)
```python {hl_lines=[1, 7]}
@wandb_init(run_mode="shared") # All ranks log to 1 shared run
@multi_node_env.task
def train_shared_global() -> float:
    torch.distributed.init_process_group("nccl")
    # ...
    # All ranks get a run object, all log to the same run
    run = get_wandb_run()
    # Each rank logs with automatic x_label identification
    run.log({"train/loss": loss.item(), "rank": rank})
    # ...
```
#### Shared + worker scope (N runs, 1 per node)
```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="shared", rank_scope="worker") # 1 shared run per worker
@multi_node_env.task
def train_shared_worker() -> float:
    torch.distributed.init_process_group("nccl")
    # ...
    # All ranks get a run object, grouped by worker
    run = get_wandb_run()
    # Ranks on the same worker share a run
    run.log({"train/loss": loss.item(), "local_rank": local_rank})
    # ...
```
### New mode: Separate run per rank
Use `run_mode="new"` when you need completely separate runs per GPU. Runs are grouped in the W&B UI for easy comparison.
#### New + global scope (N × M runs, 1 group)
```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="new") # Each rank gets its own run, all in 1 group
@multi_node_env.task
def train_new_global() -> float:
    torch.distributed.init_process_group("nccl")
    # ...
    # Each rank has its own run
    run = get_wandb_run()
    # Run IDs: {base}-rank-{global_rank}
    run.log({"train/loss": loss.item()})
    # ...
```
#### New + worker scope (N × M runs, N groups)
```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="new", rank_scope="worker") # Each rank gets own run, grouped per worker
@multi_node_env.task
def train_new_worker() -> float:
    torch.distributed.init_process_group("nccl")
    # ...
    # Each rank has its own run, grouped by worker
    run = get_wandb_run()
    # Run IDs: {base}-worker-{idx}-rank-{local_rank}
    run.log({"train/loss": loss.item()})
    # ...
```
## How it works
The plugin automatically detects distributed training by checking environment variables set by distributed launchers like `torchrun`:
| Environment variable | Description |
| -------------------- | -------------------------------------------------------- |
| `RANK` | Global rank across all processes |
| `WORLD_SIZE` | Total number of processes |
| `LOCAL_RANK` | Rank within the current node |
| `LOCAL_WORLD_SIZE` | Number of processes on the current node |
| `GROUP_RANK` | Node/worker index (0 for first node, 1 for second, etc.) |
When these variables are present, the plugin:
1. **Determines which ranks should initialize W&B** based on `run_mode` and `rank_scope`
2. **Generates unique run IDs** that include worker and rank information
3. **Creates UI links** for each W&B run (single link with `rank_scope="global"`, one per worker with `rank_scope="worker"`)
The plugin automatically adapts to your training setup, eliminating the need for manual distributed configuration.
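In simplified form, that detection amounts to reading the variables from the table above. The following is an illustrative sketch; the plugin's internal logic may differ:

```python
import os

def distributed_context(env=os.environ):
    """Return (global_rank, local_rank, worker) if launched by a distributed
    launcher like torchrun, else None."""
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return None  # not a distributed launch
    return (
        int(env["RANK"]),                # global rank across all processes
        int(env.get("LOCAL_RANK", "0")),  # rank within the current node
        int(env.get("GROUP_RANK", "0")),  # node/worker index
    )

# Example values as torchrun would set them for the second process on node 1:
ctx = distributed_context({"RANK": "5", "WORLD_SIZE": "8", "LOCAL_RANK": "1", "GROUP_RANK": "1"})
print(ctx)  # (5, 1, 1)
```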
### Run ID patterns
| Scenario | Run ID Pattern | Group |
| ---------------------------- | --------------------------------------------- | ------------------------ |
| Single-node auto/shared | `{base}` | - |
| Single-node new | `{base}-rank-{rank}` | `{base}` |
| Multi-node auto/shared (global) | `{base}` | - |
| Multi-node auto/shared (worker) | `{base}-worker-{idx}` | - |
| Multi-node new (global) | `{base}-rank-{global_rank}` | `{base}` |
| Multi-node new (worker) | `{base}-worker-{idx}-rank-{local_rank}` | `{base}-worker-{idx}` |
Where `{base}` = `{run_name}-{action_name}`
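For reference, these patterns can be reproduced with a short helper. This is a sketch only; the plugin generates run IDs internally:

```python
def wandb_run_id(base: str, run_mode: str, rank_scope: str,
                 global_rank: int, local_rank: int, worker: int) -> str:
    """Reconstruct a run ID per the pattern table above.

    base is '{run_name}-{action_name}'.
    """
    if run_mode in ("auto", "shared"):
        # One ID per worker with worker scope, otherwise just the base.
        return f"{base}-worker-{worker}" if rank_scope == "worker" else base
    # run_mode == "new": every rank gets its own run
    if rank_scope == "worker":
        return f"{base}-worker-{worker}-rank-{local_rank}"
    return f"{base}-rank-{global_rank}"

print(wandb_run_id("r1-train", "auto", "global", 0, 0, 0))  # r1-train
print(wandb_run_id("r1-train", "new", "worker", 5, 1, 1))   # r1-train-worker-1-rank-1
```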
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/wandb/sweeps ===
# Sweeps
W&B sweeps automate hyperparameter optimization by running multiple trials with different parameter combinations. The `@wandb_sweep` decorator creates a sweep and makes it easy to run trials in parallel using Flyte's distributed execution.
## Creating a sweep
Use `@wandb_sweep` to create a W&B sweep when the task executes:
```python
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)
env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)
@wandb_init
def objective():
    """Objective function that W&B calls for each trial."""
    wandb_run = wandb.run
    config = wandb_run.config
    # Simulate training with hyperparameters from the sweep
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})
@wandb_sweep
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()
    # Run 10 trials
    wandb.agent(sweep_id, function=objective, count=10)
    return sweep_id
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64, 128]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(run_sweep)
    print(f"run url: {r.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/sweep.py*
The `@wandb_sweep` decorator:
- Creates a W&B sweep when the task starts
- Makes the sweep ID available via `get_wandb_sweep_id()`
- Adds a link to the main sweeps page in the Flyte UI
Use `wandb_sweep_config()` to define the sweep parameters. This is passed to W&B's sweep API.
> **Note**
>
> Random and Bayesian searches run indefinitely, and the sweep remains in the `Running` state until you stop it.
> You can stop a running sweep from the Weights & Biases UI or from the command line.
## Running parallel agents
Flyte's distributed execution makes it easy to run multiple sweep agents in parallel, each on its own compute resources:
```python
import asyncio
from datetime import timedelta
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
    get_wandb_context,
)
env = flyte.TaskEnvironment(
    name="wandb-parallel-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-parallel-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)
@wandb_init
def objective():
    wandb_run = wandb.run
    config = wandb_run.config
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})
@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """Single agent that runs a subset of trials."""
    wandb.agent(
        sweep_id, function=objective, count=count, project=get_wandb_context().project
    )
    return agent_id
@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 20, trials_per_agent: int = 5) -> str:
    """Orchestrate multiple agents running in parallel."""
    sweep_id = get_wandb_sweep_id()
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent
    # Launch agents in parallel, each with its own resources
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="2", memory="4Gi"),
            retries=3,
            timeout=timedelta(minutes=30),
        )(agent_id=i, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]
    await asyncio.gather(*agent_tasks)
    return sweep_id
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(
        run_parallel_sweep,
        total_trials=20,
        trials_per_agent=5,
    )
    print(f"run url: {r.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/parallel_sweep.py*
This pattern provides:
- **Distributed execution**: Each agent runs on separate compute nodes
- **Resource allocation**: Specify CPU, memory, and GPU per agent
- **Fault tolerance**: Failed agents can retry without affecting others
- **Timeout protection**: Prevent runaway trials
> [!NOTE]
> `run_parallel_sweep` links to the main Weights & Biases sweeps page because the sweep ID cannot be determined at link rendering time, while `sweep_agent` links to the specific sweep URL since it receives the sweep ID as an input.

## Writing objective functions
The objective function is called by `wandb.agent()` for each trial. It must be a regular Python function decorated with `@wandb_init`:
```python {hl_lines=["1-2", "5-6"]}
@wandb_init
def objective():
"""Objective function for sweep trials."""
# Access hyperparameters from wandb.run.config
run = wandb.run
config = run.config
# Your training code
model = create_model(
learning_rate=config.learning_rate,
hidden_size=config.hidden_size,
)
for epoch in range(config.epochs):
train_loss = train_epoch(model)
val_loss = validate(model)
# Log metrics - W&B tracks these for the sweep
run.log({
"epoch": epoch,
"train_loss": train_loss,
"val_loss": val_loss,
})
# The final val_loss is used by the sweep to rank trials
```
Key points:
- Use `@wandb_init` on the objective function (not `@env.task`)
- Access hyperparameters via `wandb.run.config` (not `get_wandb_run()` since this is outside Flyte context)
- Log the metric specified in `wandb_sweep_config(metric=...)` so the sweep can optimize it
- The function is called multiple times by `wandb.agent()`, once per trial
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/wandb/downloading_logs ===
# Downloading logs
This integration enables downloading Weights & Biases run data, including metrics history, summary data, and synced files.
## Automatic download
Set `download_logs=True` to automatically download run data after your task completes:
```python {hl_lines=1}
@wandb_init(download_logs=True)
@env.task
async def train_with_download():
run = get_wandb_run()
for epoch in range(10):
run.log({"loss": 1.0 / (epoch + 1)})
return run.id
```
The downloaded data is traced by Flyte and appears as a `Dir` output in the Flyte UI. Downloaded files include:
- `summary.json`: Final summary metrics
- `metrics_history.json`: Step-by-step metrics history
- Any files synced by W&B (`requirements.txt`, `wandb_metadata.json`, etc.)
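Once downloaded, these files are plain JSON and can be inspected with the standard library. A minimal sketch, assuming `metrics_history.json` is a JSON array of per-step records (`load_run_data` is a hypothetical helper, not part of the plugin):

```python
import json
from pathlib import Path

def load_run_data(download_dir: str) -> tuple[dict, list[dict]]:
    """Read the summary and per-step metrics from a downloaded W&B run directory."""
    root = Path(download_dir)
    summary = json.loads((root / "summary.json").read_text())
    history = json.loads((root / "metrics_history.json").read_text())
    return summary, history
```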
You can also set `download_logs=True` in `wandb_config()`:
```python {hl_lines=5}
flyte.with_runcontext(
custom_context=wandb_config(
project="my-project",
entity="my-team",
download_logs=True,
),
).run(train_task)
```

For sweeps, set `download_logs=True` on `@wandb_sweep` or `wandb_sweep_config()` to download all trial data:
```python {hl_lines=1}
@wandb_sweep(download_logs=True)
@env.task
async def run_sweep():
sweep_id = get_wandb_sweep_id()
wandb.agent(sweep_id, function=objective, count=10)
return sweep_id
```

## Accessing run directories during execution
Use `get_wandb_run_dir()` to access the local W&B run directory during task execution. This is useful for writing custom files that get synced to W&B:
```python {hl_lines=[1, 7, "18-19"]}
from flyteplugins.wandb import get_wandb_run_dir
@wandb_init
@env.task
def train_with_artifacts():
run = get_wandb_run()
local_dir = get_wandb_run_dir()
# Train your model
for epoch in range(10):
run.log({"loss": 1.0 / (epoch + 1)})
# Save model checkpoint to the run directory
model_path = f"{local_dir}/model_checkpoint.pt"
torch.save(model.state_dict(), model_path)
# Save custom metrics file
with open(f"{local_dir}/custom_metrics.json", "w") as f:
json.dump({"final_accuracy": 0.95}, f)
return run.id
```
Files written to the run directory are automatically synced to W&B and can be accessed later via the W&B UI or by setting `download_logs=True`.
> [!NOTE]
> `get_wandb_run_dir()` accesses the local directory without making network calls. Files written here may have a brief delay before appearing in the W&B cloud.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/wandb/constraints_and_best_practices ===
# Constraints and best practices
## Decorator ordering
`@wandb_init` and `@wandb_sweep` must be the **outermost decorators**, applied after `@env.task`:
```python
# Correct
@wandb_init
@env.task
async def my_task():
...
# Incorrect - will not work
@env.task
@wandb_init
async def my_task():
...
```
## Traces cannot use decorators
Do not apply `@wandb_init` to traces. Traces automatically access the parent task's run via `get_wandb_run()`:
```python
# Correct
@flyte.trace
async def my_trace():
run = get_wandb_run()
if run:
run.log({"metric": 42})
# Incorrect - don't decorate traces
@wandb_init
@flyte.trace
async def my_trace():
...
```
## Maximum sweep agents
[W&B limits sweeps to a maximum of 20 concurrent agents](https://docs.wandb.ai/models/sweeps/existing-project#3-launch-agents).
## Configuration priority
Configuration is merged with the following priority (highest to lowest):
1. Decorator parameters (`@wandb_init(project="...")`)
2. Context manager (`with wandb_config(...)`)
3. Workflow-level context (`flyte.with_runcontext(custom_context=wandb_config(...))`)
4. Auto-generated values (run ID from Flyte context)
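In dict terms, this merge can be sketched by unpacking the layers from lowest to highest priority, so later (higher-priority) keys win. This is a simplified illustration of the precedence rules above, not the plugin's actual merge code:

```python
def merge_config(auto: dict, workflow: dict, context_mgr: dict, decorator: dict) -> dict:
    # Lowest priority first; later unpacking overrides earlier keys
    return {**auto, **workflow, **context_mgr, **decorator}

merged = merge_config(
    auto={"id": "run-abc-train"},
    workflow={"project": "wf-project", "entity": "my-team"},
    context_mgr={"project": "ctx-project"},
    decorator={"project": "decorator-project"},
)
# "project" resolves to the decorator value; "entity" and "id" fall through
```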
## Run ID generation
When no explicit `id` is provided, the plugin generates run IDs using the pattern:
```
{run_name}-{action_name}
```
This ensures unique, predictable IDs that can be matched between the `Wandb` link class and manual `wandb.init()` calls.
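At runtime the two components come from the Flyte context (`flyte.ctx().action.run_name` and `flyte.ctx().action.name`); reproducing the auto-generated ID in plain Python is just string formatting:

```python
def flyte_wandb_run_id(run_name: str, action_name: str) -> str:
    """Build a W&B run ID following the {run_name}-{action_name} pattern."""
    return f"{run_name}-{action_name}"
```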
## Sync delay for local files
Files written to the run directory (via `get_wandb_run_dir()`) are synced to W&B asynchronously. There may be a brief delay before they appear in the W&B cloud or can be downloaded via `download_wandb_run_dir()`.
## Shared run mode requirements
When using `run_mode="shared"`, the task requires a parent task to have already created a W&B run. Calling a task with `run_mode="shared"` as a top-level task will fail.
## Objective functions for sweeps
Objective functions passed to `wandb.agent()` should:
- Be regular Python functions (not Flyte tasks)
- Be decorated with `@wandb_init`
- Access hyperparameters via `wandb.run.config` (not `get_wandb_run()`)
- Log the metric specified in `wandb_sweep_config(metric=...)` so the sweep can optimize it
## Error handling
The plugin raises standard exceptions:
- `RuntimeError`: When `download_wandb_run_dir()` is called without a run ID and no active run exists
- `wandb.errors.AuthenticationError`: When `WANDB_API_KEY` is not set or invalid
- `wandb.errors.CommError`: When a run cannot be found in the W&B cloud
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/wandb/manual ===
# Manual integration
If you need more control over W&B initialization, you can use the `Wandb` and `WandbSweep` link classes directly instead of the decorators. This lets you call `wandb.init()` and `wandb.finish()` yourself while still getting automatic links in the Flyte UI.
## Using the Wandb link class
Add a `Wandb` link to your task to generate a link to the W&B run in the Flyte UI:
```python
import flyte
import wandb
from flyteplugins.wandb import Wandb
env = flyte.TaskEnvironment(
name="wandb-manual-init-example",
image=flyte.Image.from_debian_base(
name="wandb-manual-init-example"
).with_pip_packages("flyteplugins-wandb"),
secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)
@env.task(
links=(
Wandb(
project="my-project",
entity="my-team",
run_mode="new",
# No id parameter - link will auto-generate from run_name-action_name
),
)
)
async def train_model(learning_rate: float) -> str:
ctx = flyte.ctx()
# Generate run ID matching the link's auto-generated ID
run_id = f"{ctx.action.run_name}-{ctx.action.name}"
# Manually initialize W&B
wandb_run = wandb.init(
project="my-project",
entity="my-team",
id=run_id,
config={"learning_rate": learning_rate},
)
# Your training code
for epoch in range(10):
loss = 1.0 / (learning_rate * (epoch + 1))
wandb_run.log({"epoch": epoch, "loss": loss})
# Manually finish the run
wandb_run.finish()
return wandb_run.id
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.with_runcontext().run(
train_model,
learning_rate=0.01,
)
print(f"run url: {r.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/init_manual.py*
### With a custom run ID
If you want to use your own run ID, specify it in both the link and the `wandb.init()` call:
```python {hl_lines=[6, 14]}
@env.task(
links=(
Wandb(
project="my-project",
entity="my-team",
id="my-custom-run-id",
),
)
)
async def train_with_custom_id() -> str:
run = wandb.init(
project="my-project",
entity="my-team",
id="my-custom-run-id", # Must match the link's ID
resume="allow",
)
# Training code...
run.finish()
return run.id
```
### Adding links at runtime with override
You can also add links when calling a task using `.override()`:
```python {hl_lines=9}
@env.task
async def train_model(learning_rate: float) -> str:
# ... training code with manual wandb.init() ...
return run.id
# Add link when running the task
result = await train_model.override(
links=(Wandb(project="my-project", entity="my-team", run_mode="new"),)
)(learning_rate=0.01)
```
## Using the `WandbSweep` link class
Use `WandbSweep` to add a link to a W&B sweep:
```python
import flyte
import wandb
from flyteplugins.wandb import WandbSweep
env = flyte.TaskEnvironment(
name="wandb-manual-sweep-example",
image=flyte.Image.from_debian_base(
name="wandb-manual-sweep-example"
).with_pip_packages("flyteplugins-wandb"),
secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)
def objective():
with wandb.init(project="my-project", entity="my-team") as wandb_run:
config = wandb_run.config
for epoch in range(config.epochs):
loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
wandb_run.log({"epoch": epoch, "loss": loss})
@env.task(
links=(
WandbSweep(
project="my-project",
entity="my-team",
),
)
)
async def manual_sweep() -> str:
# Manually create the sweep
sweep_config = {
"method": "random",
"metric": {"name": "loss", "goal": "minimize"},
"parameters": {
"learning_rate": {"min": 0.0001, "max": 0.1},
"batch_size": {"values": [16, 32, 64]},
"epochs": {"value": 10},
},
}
sweep_id = wandb.sweep(sweep_config, project="my-project", entity="my-team")
# Run the sweep
wandb.agent(sweep_id, function=objective, count=10, project="my-project")
return sweep_id
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.with_runcontext().run(manual_sweep)
print(f"run url: {r.url}")
```
*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/sweep_manual.py*
The link will point to the project's sweeps page. If you have the sweep ID, you can specify it in the link:
```python {hl_lines=6}
@env.task(
links=(
WandbSweep(
project="my-project",
entity="my-team",
id="known-sweep-id",
),
)
)
async def resume_sweep() -> str:
# Resume an existing sweep
wandb.agent("known-sweep-id", function=objective, count=10)
return "known-sweep-id"
```
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/codegen ===
# Code generation
The code generation plugin turns natural-language prompts into tested, production-ready Python code.
You describe what the code should do, along with sample data, schema definitions, constraints, and typed inputs/outputs, and the plugin handles the rest: generating code, writing tests, building an isolated [code sandbox](/docs/v2/byoc//user-guide/sandboxing/code-sandboxing) with the right dependencies, running the tests, diagnosing failures, and iterating until everything passes. The result is a validated script you can execute against real data or deploy as a reusable Flyte task.
## Installation
```bash
pip install flyteplugins-codegen
# For Agent mode (Claude-only)
pip install flyteplugins-codegen[agent]
```
## Quick start
```python{hl_lines=[3, 4, 6, 12, 14, "20-25"]}
import flyte
from flyte.io import File
from flyte.sandbox import sandbox_environment
from flyteplugins.codegen import AutoCoderAgent
agent = AutoCoderAgent(model="gpt-4.1", name="summarize-sales")
env = flyte.TaskEnvironment(
name="my-env",
secrets=[flyte.Secret(key="openai_key", as_env_var="OPENAI_API_KEY")],
image=flyte.Image.from_debian_base().with_pip_packages(
"flyteplugins-codegen",
),
depends_on=[sandbox_environment],
)
@env.task
async def process_data(csv_file: File) -> tuple[float, int, int]:
result = await agent.generate.aio(
prompt="Read the CSV and compute total_revenue, total_units and row_count.",
samples={"sales": csv_file},
outputs={"total_revenue": float, "total_units": int, "row_count": int},
)
return await result.run.aio()
```
The `depends_on=[sandbox_environment]` declaration is required. It ensures the sandbox runtime is available when dynamically-created sandboxes execute.

## Two execution backends
The plugin supports two backends for generating and validating code. Both share the same `AutoCoderAgent` interface and produce the same `CodeGenEvalResult`.
### LiteLLM (default)
Uses structured-output LLM calls to generate code, detect packages, build sandbox images, run tests, diagnose failures, and iterate. Works with any model that supports structured outputs (GPT-4, Claude, Gemini, etc. via LiteLLM).
```python{hl_lines=[1, 3]}
agent = AutoCoderAgent(
name="my-task",
model="gpt-4.1",
max_iterations=10,
)
```
The LiteLLM backend follows a fixed pipeline:
```mermaid
flowchart TD
A["prompt + samples"] --> B["generate_plan"]
B --> C["generate_code"]
C --> D["detect_packages"]
D --> E["build_image"]
E --> F{skip_tests?}
F -- yes --> G["return result"]
F -- no --> H["generate_tests"]
H --> I["execute_tests"]
I --> J{pass?}
J -- yes --> G
J -- no --> K["diagnose_error"]
K --> L{error type?}
L -- "logic error" --> M["regenerate code"]
L -- "environment error" --> N["add packages, rebuild image"]
L -- "test error" --> O["fix test expectations"]
M --> I
N --> I
O --> I
```
The loop continues until tests pass or `max_iterations` is reached.
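The control flow above can be sketched as a plain Python loop, with callables standing in for the real LLM calls and sandbox runs (a simplified skeleton, not the plugin's implementation):

```python
def generate_until_passing(generate_code, run_tests, diagnose, max_iterations: int = 10):
    """Skeleton of the generate-test-fix loop; the callables are stand-ins."""
    code = generate_code(feedback=None)
    for _ in range(max_iterations):
        passed, output = run_tests(code)
        if passed:
            return code, True
        error_type, feedback = diagnose(output)
        # A logic error regenerates the code; in the real pipeline an
        # environment error would rebuild the image and a test error
        # would fix the test expectations instead.
        code = generate_code(feedback=feedback)
    return code, False
```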

### Agent (Claude)
Uses the Claude Agent SDK to autonomously generate, test, and fix code. The agent has access to `Bash`, `Read`, `Write`, and `Edit` tools and decides what to do at each step. Test execution commands (`pytest`) are intercepted and run inside isolated sandboxes.
```python{hl_lines=["3-4"]}
agent = AutoCoderAgent(
name="my-task",
model="claude-sonnet-4-5-20250929",
backend="claude",
)
```
> [!NOTE]
> Agent mode requires `ANTHROPIC_API_KEY` as a Flyte secret and is Claude-only.
**Key differences from LiteLLM:**
| | LiteLLM | Agent |
| --------------------- | --------------------------------- | ---------------------------------------------- |
| **Execution** | Fixed generate-test-fix pipeline | Autonomous agent decides actions |
| **Model support** | Any model with structured outputs | Claude only |
| **Iteration control** | `max_iterations` | `agent_max_turns` |
| **Test execution** | Direct sandbox execution | `pytest` commands intercepted via hooks |
| **Tool safety** | N/A | Commands classified as safe/denied/intercepted |
| **Observability** | Logs + token counts | Full tool call tracing in Flyte UI |
In Agent mode, Bash commands are classified before execution:
- **Safe** (`ls`, `cat`, `grep`, `head`, etc.): allowed to run directly
- **Intercepted** (`pytest`): routed to sandbox execution
- **Denied** (`apt`, `pip install`, `curl`, etc.): blocked for safety
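A simplified version of this classification might look like the following. The command lists and the default-deny behavior for unknown commands are illustrative assumptions; the plugin's actual lists may differ:

```python
SAFE = {"ls", "cat", "grep", "head", "tail", "wc"}
INTERCEPTED = {"pytest"}
DENIED = {"apt", "pip", "curl", "wget"}

def classify_command(command: str) -> str:
    """Classify a Bash command by its first token."""
    tool = command.strip().split()[0] if command.strip() else ""
    if tool in INTERCEPTED:
        return "intercepted"
    if tool in DENIED:
        return "denied"
    if tool in SAFE:
        return "safe"
    return "denied"  # default-deny for unknown commands (an assumption)
```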
## Providing data
### Sample data
Pass sample data via `samples` as `File` objects or pandas `DataFrame`s. The plugin automatically:
1. Converts DataFrames to CSV files
2. Infers [Pandera](https://pandera.readthedocs.io/) schemas from the data (column types, nullability)
3. Parses natural-language `constraints` into Pandera checks (e.g., `"quantity must be positive"` becomes `pa.Check.gt(0)`)
4. Extracts data context: column statistics, distributions, patterns, sample rows
5. Injects all of this into the LLM prompt so the generated code is aware of the exact data structure
Pandera is used purely for prompt enrichment, not runtime validation. The generated code does not import Pandera; it simply benefits from the LLM knowing the precise data structure. The generated schemas are stored on `result.generated_schemas` for inspection.
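The schema-inference step can be illustrated with a stdlib-only sketch that derives column types and nullability from CSV rows. The plugin uses Pandera for this; the sketch below only shows the idea:

```python
import csv
import io

def _is_int(v: str) -> bool:
    try:
        int(v)
    except ValueError:
        return False
    return True

def _is_float(v: str) -> bool:
    try:
        float(v)
    except ValueError:
        return False
    return True

def infer_schema(csv_text: str) -> dict[str, dict]:
    """Infer a {column: {"dtype": ..., "nullable": ...}} mapping from CSV text."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    schema = {}
    for col in rows[0].keys():
        values = [r[col] for r in rows]
        non_null = [v for v in values if v != ""]
        if all(_is_int(v) for v in non_null):
            dtype = "int"
        elif all(_is_float(v) for v in non_null):
            dtype = "float"
        else:
            dtype = "str"
        # A column is nullable if any cell was empty
        schema[col] = {"dtype": dtype, "nullable": len(non_null) < len(values)}
    return schema
```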
```python{hl_lines=[3]}
result = await agent.generate.aio(
prompt="Clean and validate the data, remove duplicates",
samples={"orders": orders_df, "products": products_file},
constraints=["quantity must be positive", "price between 0 and 10000"],
outputs={"cleaned_orders": File},
)
```
### Schema and constraints
Use `schema` to provide free-form context about data formats or target structures (e.g., a database schema). Use `constraints` to declare business rules that the generated code must respect:
```python{hl_lines=["4-17"]}
result = await agent.generate.aio(
prompt=prompt,
samples={"readings": sensor_df},
schema="""Output JSON schema for report_json:
{
"sensor_id": str,
"avg_temp": float,
"min_temp": float,
"max_temp": float,
"avg_humidity": float,
}
""",
constraints=[
"Temperature values must be between -40 and 60 Celsius",
"Humidity values must be between 0 and 100 percent",
"Output report must have one row per unique sensor_id",
],
outputs={
"report_json": str,
"total_anomalies": int,
},
)
```

### Inputs and outputs
Declare `inputs` for non-sample arguments (e.g., thresholds, flags) and `outputs` for the expected result types.
Supported output types: `str`, `int`, `float`, `bool`, `datetime.datetime`, `datetime.timedelta`, `File`.
Sample entries are automatically added as `File` inputs; you do not need to redeclare them.
```python{hl_lines=[4, 5]}
result = await agent.generate.aio(
prompt="Filter transactions above the threshold",
samples={"transactions": tx_file},
inputs={"threshold": float, "include_pending": bool},
outputs={"filtered": File, "count": int},
)
```
## Running generated code
`agent.generate()` returns a `CodeGenEvalResult`. If `result.success` is `True`, the generated code passed all tests and you can execute it against real data. If `max_iterations` (LiteLLM) or `agent_max_turns` (Agent) is reached without tests passing, `result.success` is `False` and `result.error` contains the failure details.
Both `run()` and `as_task()` return output values as a tuple in the order declared in `outputs`. If there is a single output, the value is returned directly (not wrapped in a tuple).
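The single-output unwrapping rule can be expressed as follows (an illustration of the behavior described above, not the plugin's code):

```python
def unwrap_outputs(values: tuple):
    """Return a lone output directly; keep multiple outputs as a tuple."""
    return values[0] if len(values) == 1 else values
```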
### One-shot execution with `result.run()`
Runs the generated code in a sandbox. If samples were provided during `generate()`, they are used as default inputs.
```python
# Use sample data as defaults
total_revenue, total_units, count = await result.run.aio()
# Override specific inputs
total_revenue, total_units, count = await result.run.aio(threshold=0.5)
# Sync version
total_revenue, total_units, count = result.run()
```
`result.run()` accepts optional configuration:
```python{hl_lines=["4-6"]}
total_revenue, total_units, count = await result.run.aio(
name="execute-on-data",
resources=flyte.Resources(cpu=2, memory="4Gi"),
retries=2,
timeout=600,
cache="auto",
)
```
### Reusable task with `result.as_task()`
Creates a callable sandbox task from the generated code. Useful when you want to run the same generated code against different data.
```python{hl_lines=[1, "6-7", "9-10"]}
task = result.as_task(
name="run-sensor-analysis",
resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# Call with sample defaults
report, total_anomalies = await task.aio()
# Call with different data
report, total_anomalies = await task.aio(readings=new_data_file)
```
## Error diagnosis
The LiteLLM backend classifies test failures into three categories and applies targeted fixes:
| Error type | Meaning | Action |
| ------------- | ----------------------------- | ------------------------------------------------ |
| `logic` | Bug in the generated code | Regenerate code with specific patch instructions |
| `environment` | Missing package or dependency | Add the package and rebuild the sandbox image |
| `test_error` | Bug in the generated test | Fix the test expectations |
If the same error persists after a fix, the plugin reclassifies it (e.g., `logic` to `test_error`) to try the other approach.
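One way to picture the reclassification is a simple toggle between the two code-related categories (an illustration of the idea, not the plugin's internals; the treatment of `environment` errors here is an assumption):

```python
RECLASSIFY = {"logic": "test_error", "test_error": "logic"}

def next_error_type(current: str, same_error_repeated: bool) -> str:
    # Repeated logic/test errors flip to the other category so the
    # alternate fix is attempted; environment errors keep their class.
    if same_error_repeated and current in RECLASSIFY:
        return RECLASSIFY[current]
    return current
```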
In Agent mode, the agent diagnoses and fixes issues autonomously based on error output.
## Durable execution
Code generation is expensive: it involves multiple LLM calls, image builds, and sandbox executions. Without durability, a transient failure in the pipeline (network blip, OOM, downstream service error) would force the entire process to restart from scratch: regenerating code, rebuilding images, re-running sandboxes, making additional LLM calls.
Flyte solves this through two complementary mechanisms: **replay logs** and **caching**.
### Replay logs
Flyte maintains a replay log that records every trace and task execution within a run. When a task crashes and retries, the system replays the log from the previous attempt rather than recomputing everything:
- No additional model calls
- No code regeneration
- No sandbox re-execution
- No container rebuilds
The workflow breezes through the earlier steps and resumes from the failure point. This applies as long as the traces and tasks execute in the same order and use the same inputs as the first attempt.
### Caching
Separately, Flyte can cache task results across runs. With `cache="auto"`, sandbox executions (image builds, test runs, code execution) are cached. This is useful whenever you re-run the same pipeline: not just when recovering from a crash, but across entirely separate invocations with the same inputs.
Together, replay logs handle crash recovery within a run, and caching avoids redundant work across runs.
### Non-determinism in Agent mode
One challenge with agents is that they are inherently non-deterministic: the sequence of actions can vary between runs, which could break replay.
In practice, the codegen agent follows a predictable pattern (write code, generate tests, run tests, inspect results), which works in replay's favor. The plugin also embeds logic that instructs the agent not to regenerate or re-execute steps that already completed successfully in the first run. This acts as an additional safety check alongside the replay log to account for non-determinism.

On the first attempt, the full pipeline runs. If a transient failure occurs, the system instantly replays the traces (which track model calls) and sandbox executions, allowing the pipeline to resume from the point of failure.

## Observability
### LiteLLM backend
- Logs every iteration with attempt count, error type, and package changes
- Tracks total input/output tokens across all LLM calls (available on `result.total_input_tokens` and `result.total_output_tokens`)
- Results include full conversation history for debugging (`result.conversation_history`)
### Agent backend
- Traces each tool call (name + input) via `PostToolUse` hooks
- Traces tool failures via `PostToolUseFailure` hooks
- Traces a summary when the agent finishes (total tool calls, tool distribution, final image/packages)
- Classifies Bash commands as safe, denied, or intercepted (for sandbox execution)
- All traces appear in the Flyte UI
## Examples
### Processing CSVs with different schemas
Generate code that handles varying CSV formats, then run on real data:
```python{hl_lines=[1, 3, 14, 16, 27]}
from flyteplugins.codegen import AutoCoderAgent
agent = AutoCoderAgent(
name="sales-processor",
model="gpt-4.1",
max_iterations=5,
resources=flyte.Resources(cpu=1, memory="512Mi"),
litellm_params={"temperature": 0.2, "max_tokens": 4096},
)
@env.task
async def process_sales(csv_file: File) -> dict[str, float | int]:
result = await agent.generate.aio(
prompt="Read the CSV and compute total_revenue, total_units, and transaction_count.",
samples={"csv_data": csv_file},
outputs={
"total_revenue": float,
"total_units": int,
"transaction_count": int,
},
)
if not result.success:
raise RuntimeError(f"Code generation failed: {result.error}")
total_revenue, total_units, transaction_count = await result.run.aio()
return {
"total_revenue": total_revenue,
"total_units": total_units,
"transaction_count": transaction_count,
}
```
### DataFrame analysis with constraints
Pass DataFrames directly and enforce business rules with constraints:
```python{hl_lines=[10, "15-19"]}
agent = AutoCoderAgent(
model="gpt-4.1",
name="sensor-analysis",
base_packages=["numpy"],
max_sample_rows=30,
)
@env.task
async def analyze_sensors(sensor_df: pd.DataFrame) -> tuple[File, int]:
result = await agent.generate.aio(
prompt="""Analyze IoT sensor data. For each sensor, calculate mean/min/max
temperature, mean humidity, and count warnings. Output a summary CSV.""",
samples={"readings": sensor_df},
constraints=[
"Temperature values must be between -40 and 60 Celsius",
"Humidity values must be between 0 and 100 percent",
"Output report must have one row per unique sensor_id",
],
outputs={
"report": File,
"total_anomalies": int,
},
)
if not result.success:
raise RuntimeError(f"Code generation failed: {result.error}")
task = result.as_task(
name="run-sensor-analysis",
resources=flyte.Resources(cpu=1, memory="512Mi"),
)
return await task.aio(readings=result.original_samples["readings"])
```
### Agent mode
The same task using Claude as an autonomous agent:
```python{hl_lines=[3]}
agent = AutoCoderAgent(
name="sales-agent",
backend="claude",
model="claude-sonnet-4-5-20250929",
resources=flyte.Resources(cpu=1, memory="512Mi"),
)
@env.task
async def process_sales_with_agent(csv_file: File) -> dict[str, float | int]:
result = await agent.generate.aio(
prompt="Read the CSV and compute total_revenue, total_units, and transaction_count.",
samples={"csv_data": csv_file},
outputs={
"total_revenue": float,
"total_units": int,
"transaction_count": int,
},
)
if not result.success:
raise RuntimeError(f"Agent code generation failed: {result.error}")
total_revenue, total_units, transaction_count = await result.run.aio()
return {
"total_revenue": total_revenue,
"total_units": total_units,
"transaction_count": transaction_count,
}
```
## Configuration
### LiteLLM parameters
Tune model behavior with `litellm_params`:
```python{hl_lines=["5-8"]}
agent = AutoCoderAgent(
name="my-task",
model="anthropic/claude-sonnet-4-20250514",
api_key="ANTHROPIC_API_KEY",
litellm_params={
"temperature": 0.3,
"max_tokens": 4000,
},
)
```
### Image configuration
Control the registry and Python version for sandbox images:
```python{hl_lines=["6-10"]}
from flyte.sandbox import ImageConfig
agent = AutoCoderAgent(
name="my-task",
model="gpt-4.1",
image_config=ImageConfig(
registry="my-registry.io",
registry_secret="registry-creds",
python_version=(3, 12),
),
)
```
### Skipping tests
Set `skip_tests=True` to skip test generation and execution. The agent still generates code, detects packages, and builds the sandbox image, but does not generate or run tests.
```python{hl_lines=[4]}
agent = AutoCoderAgent(
name="my-task",
model="gpt-4.1",
skip_tests=True,
)
```
> [!NOTE]
> `skip_tests` only applies to LiteLLM mode. In Agent mode, the agent autonomously decides when to test.
### Base packages
Ensure specific packages are always installed in every sandbox:
```python{hl_lines=[4]}
agent = AutoCoderAgent(
name="my-task",
model="gpt-4.1",
base_packages=["numpy", "pandas"],
)
```
## Best practices
- **One agent per task.** Each `generate()` call builds its own sandbox image and manages its own package state. Running multiple agents in the same task can cause resource contention and makes failures harder to diagnose.
- **Keep `cache="auto"` (the default).** Caching flows to all internal sandboxes, making retries near-instant. Use `"disable"` during development if you want fresh executions, or `"override"` to force re-execution and update the cached result.
- **Set `max_iterations` conservatively.** Start with 5-10 iterations. If the model cannot produce correct code in that budget, the prompt or constraints likely need refinement.
- **Provide constraints for data-heavy tasks.** Explicit constraints (e.g., `"quantity must be positive"`) produce better schemas and better generated code.
- **Inspect `result.generated_schemas`.** Review the inferred Pandera schemas to verify the model understood your data structure correctly.
## API reference
### `AutoCoderAgent` constructor
| Parameter | Type | Default | Description |
| ----------------- | ----------------- | -------------- | -------------------------------------------------------------------------------------- |
| `name` | `str` | `"auto-coder"` | Unique name for tracking and image naming |
| `model` | `str` | `"gpt-4.1"` | LiteLLM model identifier |
| `backend` | `str` | `"litellm"` | Execution backend: `"litellm"` or `"claude"` |
| `system_prompt` | `str` | `None` | Custom system prompt override |
| `api_key` | `str` | `None` | Name of the environment variable containing the LLM API key (e.g., `"OPENAI_API_KEY"`) |
| `api_base` | `str` | `None` | Custom API base URL |
| `litellm_params` | `dict` | `None` | Extra LiteLLM params (temperature, max_tokens, etc.) |
| `base_packages` | `list[str]` | `None` | Always-install pip packages |
| `resources` | `flyte.Resources` | `None` | Resources for sandbox execution (default: 1 CPU, 1Gi) |
| `image_config` | `ImageConfig` | `None` | Registry, secret, and Python version |
| `max_iterations` | `int` | `10` | Max generate-test-fix iterations (LiteLLM mode) |
| `max_sample_rows` | `int` | `100` | Rows to sample from data for LLM context |
| `skip_tests` | `bool` | `False` | Skip test generation and execution (LiteLLM mode) |
| `sandbox_retries` | `int` | `0` | Flyte task-level retries for each sandbox execution |
| `timeout` | `int` | `None` | Timeout in seconds for sandboxes |
| `env_vars` | `dict[str, str]` | `None` | Environment variables for sandboxes |
| `secrets` | `list[Secret]` | `None` | Flyte secrets for sandboxes |
| `cache` | `str` | `"auto"` | Cache behavior: `"auto"`, `"override"`, or `"disable"` |
| `agent_max_turns` | `int` | `50` | Max turns when `backend="claude"` |
### `generate()` parameters
| Parameter | Type | Default | Description |
| ------------- | ------------------------------ | -------- | --------------------------------------------------------------------------------------- |
| `prompt` | `str` | required | Natural-language task description |
| `schema` | `str` | `None` | Free-form context about data formats or target structures |
| `constraints` | `list[str]` | `None` | Natural-language constraints (e.g., `"quantity must be positive"`) |
| `samples` | `dict[str, File \| DataFrame]` | `None` | Sample data. DataFrames are auto-converted to CSV files. |
| `inputs` | `dict[str, type]` | `None` | Non-sample input types (e.g., `{"threshold": float}`) |
| `outputs` | `dict[str, type]` | `None` | Output types. Supported: `str`, `int`, `float`, `bool`, `datetime`, `timedelta`, `File` |
### `CodeGenEvalResult` fields
| Field | Type | Description |
| -------------------------- | ------------------------- | --------------------------------------------------------- |
| `success` | `bool` | Whether tests passed |
| `solution` | `CodeSolution` | Generated code (`.code`, `.language`, `.system_packages`) |
| `tests` | `str` | Generated test code |
| `output` | `str` | Test output |
| `exit_code` | `int` | Test exit code |
| `error` | `str \| None` | Error message if failed |
| `attempts` | `int` | Number of iterations used |
| `image` | `str` | Built sandbox image with all dependencies |
| `detected_packages` | `list[str]` | Pip packages detected |
| `detected_system_packages` | `list[str]` | Apt packages detected |
| `generated_schemas` | `dict[str, str] \| None` | Pandera schemas as Python code strings |
| `data_context` | `str \| None` | Extracted data context |
| `original_samples` | `dict[str, File] \| None` | Sample data as Files (defaults for `run()`/`as_task()`) |
| `total_input_tokens` | `int` | Total input tokens across all LLM calls |
| `total_output_tokens` | `int` | Total output tokens across all LLM calls |
| `conversation_history` | `list[dict]` | Full LLM conversation history for debugging |
### `CodeGenEvalResult` methods
| Method | Description |
| ----------------------------------- | ------------------------------------------------------------------ |
| `result.run(**overrides)` | Execute generated code in a sandbox. Sample data used as defaults. |
| `await result.run.aio(**overrides)` | Async version of `run()`. |
| `result.as_task(name, ...)` | Create a reusable callable sandbox task from the generated code. |
Both `run()` and `as_task()` accept optional `name`, `resources`, `retries`, `timeout`, `env_vars`, `secrets`, and `cache` parameters.
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/mlflow ===
# MLflow
The MLflow plugin integrates [MLflow](https://mlflow.org/) experiment tracking with Flyte. It provides a `@mlflow_run` decorator that automatically manages MLflow runs within Flyte tasks, with support for autologging, parent-child run sharing, distributed training, and auto-generated UI links.
The decorator works with both sync and async tasks.
## Installation
```bash
pip install flyteplugins-mlflow
```
Requires `mlflow` and `flyte`.
## Quick start
```python{hl_lines=[3, 9, "13-16", 22]}
import flyte
import mlflow
from flyteplugins.mlflow import mlflow_run, get_mlflow_run

env = flyte.TaskEnvironment(
    name="mlflow-tracking",
    resources=flyte.Resources(cpu=1, memory="500Mi"),
    image=flyte.Image.from_debian_base(name="mlflow_example").with_pip_packages(
        "flyteplugins-mlflow"
    ),
)

@mlflow_run(
    tracking_uri="http://localhost:5000",
    experiment_name="my-experiment",
)
@env.task
async def train_model(learning_rate: float) -> str:
    mlflow.log_param("lr", learning_rate)
    mlflow.log_metric("loss", 0.42)

    run = get_mlflow_run()
    return run.info.run_id
```


> [!NOTE]
> `@mlflow_run` must be the outermost decorator, before `@env.task`:
>
> ```python{hl_lines=["1-2"]}
> @mlflow_run # outermost
> @env.task # innermost
> async def my_task(): ...
> ```
## Autologging
Enable MLflow's autologging to automatically capture parameters, metrics, and models without manual `mlflow.log_*` calls.
### Generic autologging
```python{hl_lines=[1]}
@mlflow_run(autolog=True)
@env.task
async def train():
    from sklearn.linear_model import LogisticRegression
    model = LogisticRegression()
    model.fit(X, y)  # Parameters, metrics, and model are logged automatically
```
### Framework-specific autologging
Pass `framework` to use a framework-specific autolog implementation:
```python{hl_lines=[3]}
@mlflow_run(
autolog=True,
framework="sklearn",
log_models=True,
log_datasets=False,
)
@env.task
async def train_sklearn():
    from sklearn.ensemble import RandomForestClassifier
    model = RandomForestClassifier(n_estimators=100)
    model.fit(X_train, y_train)
```
Any framework that provides an `mlflow.{framework}.autolog()` function is supported. You can find the full list of supported frameworks [here](https://mlflow.org/docs/latest/ml/tracking/autolog/#supported-libraries).
You can pass additional autolog parameters via `autolog_kwargs`:
```python{hl_lines=[4]}
@mlflow_run(
autolog=True,
framework="pytorch",
autolog_kwargs={"log_every_n_epoch": 5},
)
@env.task
async def train_pytorch():
    ...
```

## Run modes
The `run_mode` parameter controls how MLflow runs are created and shared across tasks:
| Mode | Behavior |
| ------------------ | --------------------------------------------------------------------- |
| `"auto"` (default) | Reuse the parent's run if one exists, otherwise create a new run |
| `"new"` | Always create a new independent run |
| `"nested"` | Create a new run nested under the parent via `mlflow.parentRunId` tag |
### Sharing a run across tasks
With `run_mode="auto"` (the default), child tasks reuse the parent's MLflow run:
```python{hl_lines=[1, 5, 7]}
@mlflow_run
@env.task
async def parent_task():
    mlflow.log_param("stage", "parent")
    await child_task()  # Shares the same MLflow run

@mlflow_run
@env.task
async def child_task():
    mlflow.log_metric("child_metric", 1.0)  # Logged to the parent's run
```
### Creating independent runs
Use `run_mode="new"` when a task should always create its own top-level MLflow run, completely independent of any parent:
```python{hl_lines=[1]}
@mlflow_run(run_mode="new")
@env.task
async def standalone_experiment():
    mlflow.log_param("experiment_type", "baseline")
    mlflow.log_metric("accuracy", 0.95)
```
### Nested runs
Use `run_mode="nested"` to create a child run that appears under the parent in the MLflow UI. This works across processes and containers via the `mlflow.parentRunId` tag.

This is the recommended pattern for hyperparameter optimization, where each trial should be tracked as a child of the parent study run:
```python{hl_lines=[1, 2, 15, "22-25"]}
from flyteplugins.mlflow import Mlflow
@mlflow_run(run_mode="nested")
@env.task(links=[Mlflow()])
async def run_trial(trial_number: int, n_estimators: int, max_depth: int) -> float:
    """Each trial creates a nested MLflow run under the parent."""
    mlflow.log_params({"n_estimators": n_estimators, "max_depth": max_depth})
    mlflow.log_param("trial_number", trial_number)

    model = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth)
    model.fit(X_train, y_train)
    rmse = float(np.sqrt(mean_squared_error(y_val, model.predict(X_val))))
    mlflow.log_metric("rmse", rmse)
    return rmse

@mlflow_run
@env.task
async def hpo_search(n_trials: int = 30) -> str:
    """Parent run tracks the overall study."""
    run = get_mlflow_run()
    mlflow.log_param("n_trials", n_trials)

    # Run trials in parallel; each gets a nested MLflow run
    rmses = await asyncio.gather(
        *(run_trial(trial_number=i, **params) for i, params in enumerate(trial_params))
    )
    mlflow.log_metric("best_rmse", min(rmses))
    return run.info.run_id
```

## Workflow-level configuration
Use `mlflow_config()` with `flyte.with_runcontext()` to set MLflow configuration for an entire workflow. All `@mlflow_run`-decorated tasks in the workflow inherit these settings:
```python{hl_lines=[1, "4-8"]}
from flyteplugins.mlflow import mlflow_config

r = flyte.with_runcontext(
    custom_context=mlflow_config(
        tracking_uri="http://localhost:5000",
        experiment_id="846992856162999",
        tags={"team": "ml"},
    )
).run(train_model, learning_rate=0.001)
```
This eliminates the need to repeat `tracking_uri` and experiment settings on every `@mlflow_run` decorator.
### Per-task overrides
Use `mlflow_config()` as a context manager inside a task to override configuration for specific child tasks:
```python{hl_lines=[6]}
@mlflow_run
@env.task
async def parent_task():
    await shared_child()  # Inherits parent config

    with mlflow_config(run_mode="new", tags={"role": "independent"}):
        await independent_child()  # Gets its own run
```
### Configuration priority
Settings are resolved in priority order:
1. Explicit `@mlflow_run` decorator arguments
2. `mlflow_config()` context configuration
3. Environment variables (for `tracking_uri`)
4. MLflow defaults
## Distributed training
In distributed training, only rank 0 logs to MLflow by default. The plugin detects rank automatically from the `RANK` environment variable:
```python{hl_lines=[1, "4-6"]}
@mlflow_run
@env.task
async def distributed_train():
    # Only rank 0 creates an MLflow run and logs metrics.
    # Other ranks execute the task function directly without
    # creating an MLflow run or incurring any MLflow overhead.
    ...
```
On non-rank-0 workers, no MLflow run is created and `get_mlflow_run()` returns `None`. The task function still executes normally; only the MLflow instrumentation is skipped.

You can also set rank explicitly:
```python{hl_lines=[1]}
@mlflow_run(rank=0)
@env.task
async def train():
    ...
```
## MLflow UI links
The `Mlflow` link class displays links to the MLflow UI in the Flyte UI.
Since the MLflow run is created inside the task at execution time, the run URL cannot be determined before the task starts. Links are only shown when a run URL is already available from context, either because a parent task created the run, or because an explicit URL is provided.
The recommended pattern is for the parent task to create the MLflow run, and child tasks that inherit the run (via `run_mode="auto"`) display the link to that run. For nested runs (`run_mode="nested"`), children display a link to the parent run.
### Setup
Set `link_host` via `mlflow_config()` and attach `Mlflow()` links to child tasks:
```python{hl_lines=[4, 17]}
from flyteplugins.mlflow import Mlflow, mlflow_config

@mlflow_run
@env.task(links=[Mlflow()])
async def child_task():
    ...  # Link points to the parent's MLflow run

@mlflow_run
@env.task
async def parent_task():
    await child_task()

if __name__ == "__main__":
    r = flyte.with_runcontext(
        custom_context=mlflow_config(
            tracking_uri="http://localhost:5000",
            link_host="http://localhost:5000",
        )
    ).run(parent_task)
```
> [!NOTE]
> `Mlflow()` is instantiated without a `link` argument because the URL is auto-generated at runtime. When the parent task creates an MLflow run, the plugin builds the URL from `link_host` and the run's experiment/run IDs, then propagates it to child tasks via the Flyte context. Passing an explicit `link` would bypass this auto-generation.
### Custom URL templates
The default link format is:
```
{host}/#/experiments/{experiment_id}/runs/{run_id}
```
For platforms like Databricks that use a different URL structure, provide a custom template:
```python{hl_lines=[3]}
mlflow_config(
    link_host="https://dbc-xxx.cloud.databricks.com",
    link_template="{host}/ml/experiments/{experiment_id}/runs/{run_id}",
)
```
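Given the documented placeholders (`{host}`, `{experiment_id}`, `{run_id}`), a custom template is presumably an ordinary `str.format` pattern, so you can sanity-check one locally before deploying (the host and IDs below are example values):

```python
# Databricks-style template with the three documented placeholders
template = "{host}/ml/experiments/{experiment_id}/runs/{run_id}"
url = template.format(
    host="https://dbc-xxx.cloud.databricks.com",
    experiment_id="846992856162999",
    run_id="abc123",
)
```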
### Explicit links
If you know the run URL ahead of time, you can set it directly:
```python{hl_lines=[1]}
@env.task(links=[Mlflow(link="https://mlflow.example.com/#/experiments/1/runs/abc123")])
async def my_task():
...
```
### Link behavior by run mode
| Run mode | Link behavior |
| ---------- | ---------------------------------------------------------------------------------------------- |
| `"auto"` | Parent link propagates to child tasks sharing the run |
| `"new"` | Parent link is cleared; no link is shown until the task's own run is available to its children |
| `"nested"` | Parent link is kept and renamed to "MLflow (parent)" |
## Automatic Flyte tags
When running inside Flyte, the plugin automatically tags MLflow runs with execution metadata:
| Tag | Description |
| ------------------- | ---------------- |
| `flyte.action_name` | Task action name |
| `flyte.run_name` | Flyte run name |
| `flyte.project` | Flyte project |
| `flyte.domain` | Flyte domain |
These tags are merged with any user-provided tags.
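These tags make it possible to find every MLflow run produced by a given Flyte run. A minimal sketch (the helper function is invented for illustration; the filter syntax is standard MLflow, where backticks quote tag keys that contain dots):

```python
def flyte_run_filter(flyte_run_name: str) -> str:
    """Build an MLflow search filter matching all runs tagged with a Flyte run name."""
    return f"tags.`flyte.run_name` = '{flyte_run_name}'"

# With an MLflow client available, you could then query:
# mlflow.search_runs(filter_string=flyte_run_filter("my-flyte-run"))
```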
## API reference
### `mlflow_run` and `mlflow_config`
`mlflow_run` is a decorator that manages MLflow runs for Flyte tasks. `mlflow_config` creates workflow-level configuration or per-task overrides. Both accept the same core parameters:
| Parameter | Type | Default | Description |
| ----------------- | ---------------- | -------- | ----------------------------------------------------------------------------- |
| `run_mode` | `str` | `"auto"` | `"auto"`, `"new"`, or `"nested"` |
| `tracking_uri` | `str` | `None` | MLflow tracking server URL |
| `experiment_name` | `str` | `None` | MLflow experiment name (raises `ValueError` if combined with `experiment_id`) |
| `experiment_id` | `str` | `None` | MLflow experiment ID (raises `ValueError` if combined with `experiment_name`) |
| `run_name` | `str` | `None` | Human-readable run name (raises `ValueError` if combined with `run_id`) |
| `run_id` | `str` | `None` | Explicit MLflow run ID (raises `ValueError` if combined with `run_name`) |
| `tags` | `dict[str, str]` | `None` | Tags for the run |
| `autolog` | `bool` | `False` | Enable MLflow autologging |
| `framework` | `str` | `None` | Framework for autolog (e.g. `"sklearn"`, `"pytorch"`) |
| `log_models` | `bool` | `None` | Log models automatically (requires `autolog`) |
| `log_datasets` | `bool` | `None` | Log datasets automatically (requires `autolog`) |
| `autolog_kwargs` | `dict` | `None` | Extra parameters for `mlflow.autolog()` |
Additional keyword arguments are passed to `mlflow.start_run()`.
`mlflow_run` also accepts:
| Parameter | Type | Default | Description |
| --------- | ----- | ------- | -------------------------------------------------------- |
| `rank` | `int` | `None` | Process rank for distributed training (only rank 0 logs) |
`mlflow_config` also accepts:
| Parameter | Type | Default | Description |
| --------------- | ----- | ------- | --------------------------------------------------------------------------- |
| `link_host` | `str` | `None` | MLflow UI host for auto-generating links |
| `link_template` | `str` | `None` | Custom URL template (placeholders: `{host}`, `{experiment_id}`, `{run_id}`) |
### `get_mlflow_run`
Returns the current `mlflow.ActiveRun` if within a `@mlflow_run`-decorated task. Returns `None` otherwise.
```python
from flyteplugins.mlflow import get_mlflow_run
run = get_mlflow_run()
if run:
    print(run.info.run_id)
```
### `get_mlflow_context`
Returns the current `mlflow_config` settings from the Flyte context, or `None` if no MLflow configuration is set. Useful for inspecting the inherited configuration inside a task:
```python
from flyteplugins.mlflow import get_mlflow_context
@mlflow_run
@env.task
async def my_task():
    config = get_mlflow_context()
    if config:
        print(config.tracking_uri, config.experiment_id)
```
### `Mlflow`
Link class for displaying MLflow UI links in the Flyte console.
| Field | Type | Default | Description |
| ------ | ----- | ---------- | --------------------------------------- |
| `name` | `str` | `"MLflow"` | Display name for the link |
| `link` | `str` | `""` | Explicit URL (bypasses auto-generation) |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference ===
# Reference
This section provides the reference material for the Flyte SDK and CLI.
To get started, add `flyte` to your project
```shell
$ uv pip install --no-cache --upgrade flyte
```
This will install both the Flyte SDK and CLI.
### **Flyte SDK**
The Flyte SDK provides the core Python API for building workflows and apps on your Union instance.
### **Flyte CLI**
The Flyte CLI is the command-line interface for interacting with your Union instance.
### **Migration from Flyte 1 to Flyte 2**
Comprehensive reference for migrating Flyte 1 workflows to Flyte 2.
## Subpages
- **LLM-optimized documentation**
- **Migration from Flyte 1 to Flyte 2**
- **Flyte CLI**
- **Flyte SDK**
- **Integrations**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-context ===
# LLM-optimized documentation
This site provides LLM-optimized documentation at four levels of granularity,
designed for use by AI coding agents such as
[Claude Code](https://docs.anthropic.com/en/docs/claude-code),
[Cursor](https://www.cursor.com/),
[Windsurf](https://windsurf.com/),
and similar tools.
These files also follow the [`llms.txt` convention](https://llmstxt.org/),
making them discoverable by AI search engines.
Every page on the site also has an **LLM-optimized** section in the right-hand sidebar
that points to:
* This "LLM-optimized documentation" page (for explanation).
* An LLM-optimized version of that page.
* An LLM-optimized single file containing the whole section (only on top pages of key sections).
* The full site index for LLMs.
All links within LLM-optimized files use absolute URLs (`https://www.union.ai/docs/...`),
so files work correctly when copied locally and used outside the docs site.
## Per-page Markdown (`page.md`)
Every page on this site has a parallel LLM-optimized version in clean Markdown,
accessible at the same URL path with `/page.md` appended and via the "**This page**" link in the "**LLM-optimized**" section of the right sidebar.
For example, this page is at:
* **LLM-optimized documentation**
and its LLM-optimized version is at:
* **LLM-optimized documentation**
Section landing pages include a `## Subpages` table listing child pages with their H2/H3 headings,
making it easy to identify the right page to fetch.
## Section bundles (`section.md`)
For key documentation sections, a curated bundle file concatenates all pages in the section
into a single `section.md` file.
These are accessible at the same URL path as the top page of the section, with `/section.md` appended and via the "**This section in one file**" link in the "**LLM-optimized**" section of the right sidebar.
These `section.md` files are sized to fit within modern LLM context windows
and are ideal for pasting into a prompt or adding to project context.
Available bundle files:
{{< llm-readable-list >}}
## Page index (`llms.txt`)
The `llms.txt` file is a compact index of all LLM-optimized pages, organized by section.
Each page entry includes the H2/H3 headings found on that page, so an agent can identify
the right page to fetch without downloading it first.
Sections that have a `section.md` bundle are marked in the index.
Download it and append its contents to the `AGENTS.md`, `CLAUDE.md` or similar file in your project root.
Make sure you append the index into a file that is **loaded into context by default** by your coding tool.
Adding it as a skill or tool is less effective because the agent must decide to load it
rather than having the information always available.
* [`llms.txt`](https://www.union.ai/docs/v2/flyte/llms.txt) (~32K tokens)
> [!NOTE]
> You are viewing the **Flyte OSS** docs.
> To get the `llms.txt` for a different product variant, use the variant selector at the top of the page.
## Full documentation (`llms-full.txt`)
The `llms-full.txt` file contains the entire Flyte version 2.0 documentation as a single Markdown file.
This file is very large and is not suitable for direct inclusion in an LLM context window,
but it may be useful for RAG-based tools.
* [`llms-full.txt`](https://www.union.ai/docs/v2/flyte/llms-full.txt) (~1.4M tokens)
> [!NOTE]
> You are viewing the **Flyte OSS** docs.
> To get the `llms-full.txt` for a different product variant, use the variant selector at the top of the page.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration ===
# Migration from Flyte 1 to Flyte 2
> [!NOTE]
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
This section provides a comprehensive reference for migrating Flyte 1 (flytekit) workflows to Flyte 2 (flyte SDK).
For a quick-start overview of the migration process, see **From Flyte 1 to 2 > Migration from Flyte 1 to Flyte 2** in the User Guide.
## Key API changes at a glance
| Use case | Flyte 1 | Flyte 2 |
| ----------------------------- | --------------------------- | --------------------------------------- |
| Environment management | N/A | `TaskEnvironment` |
| Perform basic computation | `@task` | `@env.task` |
| Combine tasks into a workflow | `@workflow` | `@env.task` |
| Create dynamic workflows | `@dynamic` | `@env.task` |
| Fanout parallelism | `flytekit.map` | Python `for` loop with `asyncio.gather` |
| Conditional execution | `flytekit.conditional` | Python `if-elif-else` |
| Catching workflow failures | `@workflow(on_failure=...)` | Python `try-except` |
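The right-hand column of this table boils down to ordinary Python control flow. A sketch of the fanout and conditional patterns, with plain async functions standing in for `@env.task`-decorated tasks (Flyte setup omitted; failures would likewise be caught with an ordinary `try/except`):

```python
import asyncio

# Plain async functions stand in for @env.task-decorated tasks here.
async def score(x: int) -> int:
    return x * 2

async def pipeline(xs: list[int], threshold: int) -> str:
    # Fanout parallelism: a comprehension plus asyncio.gather replaces flytekit.map
    scores = await asyncio.gather(*(score(x) for x in xs))
    # Conditional execution: plain if/else replaces flytekit.conditional
    if max(scores) > threshold:
        return "alert"
    return "ok"
```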
## Topics
### **Migration from Flyte 1 to Flyte 2 > Philosophy and imports**
Key paradigm shifts and package import mapping from flytekit to flyte.
### **Migration from Flyte 1 to Flyte 2 > Container images**
Migrate from ImageSpec to flyte.Image with the fluent builder API.
### **Migration from Flyte 1 to Flyte 2 > Configuration and CLI**
Config file format changes and CLI command mapping.
### **Migration from Flyte 1 to Flyte 2 > Tasks and workflows**
Migrate @task, @workflow, and @dynamic to TaskEnvironment and @env.task.
### **Migration from Flyte 1 to Flyte 2 > Secrets, resources, and caching**
Updated patterns for secrets access, resource configuration, and caching.
### **Migration from Flyte 1 to Flyte 2 > Parallelism and async**
Migrate map_task to flyte.map and asyncio.gather patterns.
### **Migration from Flyte 1 to Flyte 2 > Triggers and dynamic workflows**
Migrate LaunchPlan schedules to Triggers and @dynamic to regular tasks.
### **Migration from Flyte 1 to Flyte 2 > Examples and common gotchas**
Complete migration examples and common pitfalls to avoid.
## Subpages
- **Migration from Flyte 1 to Flyte 2 > Philosophy and imports**
- **Migration from Flyte 1 to Flyte 2 > Container images**
- **Migration from Flyte 1 to Flyte 2 > Configuration and CLI**
- **Migration from Flyte 1 to Flyte 2 > Tasks and workflows**
- **Migration from Flyte 1 to Flyte 2 > Secrets, resources, and caching**
- **Migration from Flyte 1 to Flyte 2 > Parallelism and async**
- **Migration from Flyte 1 to Flyte 2 > Triggers and dynamic workflows**
- **Migration from Flyte 1 to Flyte 2 > Examples and common gotchas**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/overview ===
# Philosophy and imports
## Key paradigm shifts
| Concept | Flyte 1 (flytekit) | Flyte 2 (flyte) |
|---------|--------------|------------|
| Workflow definition | `@workflow` decorator (DSL-constrained) | Tasks calling tasks (pure Python) |
| Task configuration | Per-task decorator parameters | `TaskEnvironment` (shared config) |
| Parallelism | `map_task()` function | `flyte.map()` or `asyncio.gather()` |
| Conditionals | `flytekit.conditional()` | Native Python `if/else` |
| Error handling | Decorator-based retries | Python `try/except` + retries |
| Execution model | Static DAG compilation | Dynamic pure Python execution |
## What Flyte 2 eliminates
- **`@workflow` decorator**: No longer exists. Workflows are just tasks that call other tasks.
- **`@dynamic` decorator**: No longer needed. All tasks can have dynamic behavior.
- **DSL constraints**: No more restrictions on what Python constructs you can use.
- **Separate workflow/task execution contexts**: Everything runs as a task.
## What Flyte 2 introduces
- **`TaskEnvironment`**: Centralized configuration for groups of tasks.
- **Native async support**: First-class `async/await` with distributed execution.
- **`flyte.map()`**: Simplified parallel execution with generator support.
- **`Trigger`**: Task-based scheduling (replaces LaunchPlan schedules).
- **Pure Python workflows**: Full Python flexibility in orchestration logic.
For more on the pure Python model, see [Pure Python](../../user-guide/flyte-2/pure-python).
For more on the async model, see [Asynchronous model](../../user-guide/flyte-2/async).
## Package imports
### Basic import changes
### Flyte 1
```python
import flytekit
from flytekit import task, workflow, dynamic, map_task
from flytekit import ImageSpec, Resources, Secret
from flytekit import current_context, LaunchPlan, CronSchedule
```
### Flyte 2
```python
import flyte
from flyte import TaskEnvironment, Resources, Secret
from flyte import Image, Trigger, Cron
```
### Import mapping table
| Flyte 1 import | Flyte 2 import | Notes |
|-----------|-----------|-------|
| `flytekit.task` | `env.task` | Decorator from TaskEnvironment |
| `flytekit.workflow` | `env.task` | Workflows are now tasks |
| `flytekit.dynamic` | `env.task` | All tasks can be dynamic |
| `flytekit.map_task` | `flyte.map` / `asyncio.gather` | Different API |
| `flytekit.ImageSpec` | `flyte.Image` | Different API |
| `flytekit.Resources` | `flyte.Resources` | Similar API |
| `flytekit.Secret` | `flyte.Secret` | Different access pattern |
| `flytekit.current_context()` | `flyte.ctx()` | Different API |
| `flytekit.LaunchPlan` | `flyte.Trigger` | Different concept |
| `flytekit.CronSchedule` | `flyte.Cron` | Used with Trigger |
| `flytekit.conditional` | Native `if/else` | No longer needed |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/images ===
# Container images
Flyte 1's `ImageSpec` is replaced by Flyte 2's `flyte.Image` with a fluent builder API.
## Basic migration
### Flyte 1
```python
from flytekit import ImageSpec
image_spec = ImageSpec(
    name="my-image",
    registry="ghcr.io/myorg",
    python_version="3.11",
    packages=["pandas", "numpy"],
    apt_packages=["curl", "git"],
    env={"MY_VAR": "value"},
)

@task(container_image=image_spec)
def my_task(): ...
```
### Flyte 2
```python
from flyte import Image, TaskEnvironment
image = (
    Image.from_debian_base(
        name="my-image",
        registry="ghcr.io/myorg",
        python_version=(3, 11),
    )
    .with_pip_packages("pandas", "numpy")
    .with_apt_packages("curl", "git")
    .with_env_vars({"MY_VAR": "value"})
)

env = TaskEnvironment(name="my_env", image=image)

@env.task
def my_task(): ...
```
## Image constructor methods
| Method | Description | Use case |
|--------|-------------|----------|
| `Image.from_debian_base()` | Start from Flyte's Debian base | Most common, includes Flyte SDK |
| `Image.from_base(image_uri)` | Start from any existing image | Custom base images |
| `Image.from_dockerfile(path)` | Build from Dockerfile | Complex custom builds |
| `Image.from_uv_script(path)` | Build from UV script | UV-based projects |
## Image builder methods (chainable)
```python
image = (
    Image.from_debian_base(
        python_version=(3, 12),
        registry="ghcr.io/myorg",
        name="my-image",
    )
    # Python packages
    .with_pip_packages("pandas", "numpy>=1.24.0", pre=True)
    .with_requirements(Path("requirements.txt"))
    .with_uv_project(Path("pyproject.toml"))
    .with_poetry_project(Path("pyproject.toml"))
    # System packages
    .with_apt_packages("curl", "git", "build-essential")
    # Custom commands
    .with_commands([
        "mkdir -p /app/data",
        "chmod +x /app/scripts/*.sh",
    ])
    # Files
    .with_source_file(Path("config.yaml"), dst="/app/config.yaml")
    .with_source_folder(Path("./src"), dst="/app/src")
    .with_dockerignore(Path(".dockerignore"))
    # Environment
    .with_env_vars({"LOG_LEVEL": "INFO", "WORKERS": "4"})
    .with_workdir("/app")
)
```
## Builder configuration (local vs remote)
Flyte 2 supports two build modes:
**Local builder** (default): Builds using local Docker and pushes to registry. Requires Docker installed and authenticated to registry.
**Remote builder** (Union instances): Builds on Union's ImageBuilder. No local Docker required. Faster in CI/CD.
```yaml
# In config file
image:
builder: local # or "remote"
```
```python
# Or via code
flyte.init(image_builder="local") # or "remote"
flyte.init_from_config(image_builder="local") # or "remote"
```
## Private registry with secrets
### Flyte 1
```python
image_spec = ImageSpec(
    registry="private.registry.com",
    registry_config="/path/to/config.json",
)
```
### Flyte 2
First create the secret:
```shell
flyte create secret --type image_pull my-registry-secret --from-file ~/.docker/config.json
```
Then reference it in the image:
```python
image = Image.from_debian_base(
    registry="private.registry.com",
    name="my-image",
    registry_secret="my-registry-secret",
)
```
## Parameter mapping
| Flyte 1 ImageSpec | Flyte 2 Image | Notes |
|--------------|----------|-------|
| `name` | `name` (in constructor) | Same |
| `registry` | `registry` (in constructor) | Same |
| `python_version` | `python_version` (tuple) | `"3.11"` becomes `(3, 11)` |
| `packages` | `.with_pip_packages()` | Method instead of param |
| `apt_packages` | `.with_apt_packages()` | Method instead of param |
| `conda_packages` | N/A | Use micromamba or custom base |
| `requirements` | `.with_requirements()` | Supports txt, poetry.lock, uv.lock |
| `env` | `.with_env_vars()` | Method instead of param |
| `commands` | `.with_commands()` | Method instead of param |
| `copy` | `.with_source_file/folder()` | More explicit methods |
| `source_root` | `.with_source_folder()` | Method instead of param |
| `pip_index` | `index_url` param in `.with_pip_packages()` | Moved to method |
| `pip_extra_index_url` | `extra_index_urls` param | Moved to method |
| `base_image` | `Image.from_base()` | Different constructor |
| `builder` | Config file or `flyte.init()` | Global setting |
| `platform` | `platform` (in constructor) | Tuple: `("linux/amd64", "linux/arm64")` |
For full details on container images in Flyte 2, see [Container images](../../user-guide/task-configuration/container-images).
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/configuration-and-cli ===
# Configuration and CLI
## Configuration files
### Config file location
| Version | Default location | Environment variable |
|---------|-----------------|---------------------|
| Flyte 1 | `~/.flyte/config.yaml` | `FLYTECTL_CONFIG` |
| Flyte 2 | `~/.flyte/config.yaml` | `FLYTE_CONFIG` |
### Config format
### Flyte 1
```yaml
union:
connection:
host: dns:///your-cluster.hosted.unionai.cloud
insecure: false
auth:
type: Pkce
admin:
endpoint: dns:///your-cluster.hosted.unionai.cloud
insecure: false
authType: Pkce
```
### Flyte 2
```yaml
admin:
endpoint: dns:///your-cluster.hosted.unionai.cloud
image:
builder: remote # or "local"
task:
domain: development
org: your-org
project: your-project
```
### Key config differences
| Setting | Flyte 1 location | Flyte 2 location |
|---------|-------------|-------------|
| Endpoint | `admin.endpoint` or `union.connection.host` | `admin.endpoint` |
| Auth type | `admin.authType` or `union.auth.type` | Generally auto-detected (PKCE default) |
| Project | CLI flag `-p` | `task.project` (can set default) |
| Domain | CLI flag `-d` | `task.domain` (can set default) |
| Organization | CLI flag `--org` | `task.org` (can set default) |
| Image builder | N/A | `image.builder` (`local` or `remote`) |
### Specifying config via CLI
### Flyte 1
```shell
pyflyte --config ~/.flyte/config.yaml run ...
```
### Flyte 2
```shell
flyte --config ~/.flyte/config.yaml run ...
flyte -c ~/.flyte/config.yaml run ...
```
### Specifying config in code
```python
import flyte
# From config file
flyte.init_from_config() # Auto-discovers config
flyte.init_from_config("path/to/config.yaml") # Explicit path
# Programmatic configuration
flyte.init(
    endpoint="flyte.example.com",
    insecure=False,
    project="my-project",
    domain="development",
)
```
## CLI commands
### Command mapping
| Flyte 1 command | Flyte 2 command | Notes |
|------------|------------|-------|
| `pyflyte run` | `flyte run` | Similar but different flags |
| `pyflyte run --remote` | `flyte run` | Remote is default in Flyte 2 |
| `pyflyte run` (no --remote) | `flyte run --local` | Local execution |
| `pyflyte register` | `flyte deploy` | Different concept |
| `pyflyte package` | N/A | Not needed in Flyte 2 |
| `pyflyte serialize` | N/A | Not needed in Flyte 2 |
### Running tasks
### Flyte 1
```shell
# Run locally
pyflyte run my_module.py my_workflow --arg1 value1
# Run remotely
pyflyte --config config.yaml run --remote my_module.py my_workflow --arg1 value1
```
### Flyte 2
```shell
# Run remotely (default)
flyte run my_module.py my_task --arg1 value1
# Run locally
flyte run --local my_module.py my_task --arg1 value1
# With explicit config
flyte --config config.yaml run my_module.py my_task --arg1 value1
```
### Key CLI flag differences
| Flyte 1 flag | Flyte 2 flag | Notes |
|---------|---------|-------|
| `--remote` | (default) | Remote is default in Flyte 2 |
| `--copy-all` | `--copy-style all` | File copying |
| N/A | `--copy-style loaded_modules` | Default: only imported modules |
| N/A | `--copy-style none` | Don't copy files |
| `-p, --project` | `--project` | Same |
| `-d, --domain` | `--domain` | Same |
| `-i, --image` | `--image` | Same format |
| N/A | `--follow, -f` | Follow execution logs |
### Deploying
### Flyte 1
```shell
pyflyte register my_module.py -p my-project -d development
```
### Flyte 2
```shell
# Deploy task environments
flyte deploy my_module.py my_env --project my-project --domain development
# Deploy all environments in file
flyte deploy --all my_module.py
# Deploy with version
flyte deploy --version v1.0.0 my_module.py my_env
# Recursive deployment
flyte deploy --recursive --all ./src
# Dry run (preview)
flyte deploy --dry-run my_module.py my_env
```
### Running deployed tasks
```shell
# Run a deployed task
flyte run deployed-task my_env.my_task --arg1 value1
# Run specific version
flyte run deployed-task my_env.my_task:v1.0.0 --arg1 value1
```
### Complete Flyte 2 CLI options
```shell
# Global options
flyte --endpoint # Override endpoint
flyte --config # Config file path
flyte --org # Organization
flyte -v, --verbose # Verbose output (can repeat: -vvv)
flyte --output-format [table|json] # Output format
# Run command options
flyte run [OPTIONS] [TASK_ARGS]
--local # Run locally
--project # Project
--domain # Domain
--copy-style [loaded_modules|all|none] # File copying
--root-dir # Source root directory
--follow, -f # Follow logs
--image [NAME=]URI # Image override
--name # Execution name
--service-account # K8s service account
# Deploy command options
flyte deploy [OPTIONS] [ENV_NAME]
--project # Project
--domain # Domain
--version # Version
--dry-run # Preview without deploying
--copy-style [loaded_modules|all|none] # File copying
--recursive, -r # Deploy recursively
--all # Deploy all environments
--image [NAME=]URI # Image override
```
For full CLI reference, see [Flyte CLI](../flyte-cli).
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/tasks-and-workflows ===
# Tasks and workflows
## Basic task migration
### Flyte 1
```python
from flytekit import task, Resources
@task(
cache=True,
cache_version="1.0",
retries=3,
timeout=3600,
container_image="python:3.11",
requests=Resources(cpu="1", mem="2Gi"),
limits=Resources(cpu="2", mem="4Gi"),
)
def my_task(x: int) -> int:
return x * 2
```
### Flyte 2
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env",
image="python:3.11",
resources=flyte.Resources(cpu="1", memory="2Gi"),
cache="auto",
)
@env.task(retries=3, timeout=3600)
def my_task(x: int) -> int:
return x * 2
```
## Workflow to task migration
In Flyte 2 there is no `@workflow` decorator. Workflows are tasks that call other tasks.
### Flyte 1
```python
from flytekit import task, workflow
@task
def step1(x: int) -> int:
return x + 1
@task
def step2(y: int) -> int:
return y * 2
@task
def step3(z: int) -> str:
return f"Result: {z}"
@workflow
def my_workflow(x: int) -> str:
a = step1(x=x)
b = step2(y=a)
c = step3(z=b)
return c
```
### Flyte 2 Sync
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
def step1(x: int) -> int:
return x + 1
@env.task
def step2(y: int) -> int:
return y * 2
@env.task
def step3(z: int) -> str:
return f"Result: {z}"
@env.task
def main(x: int) -> str:
a = step1(x)
b = step2(a)
c = step3(b)
return c
```
### Flyte 2 Async
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def step1(x: int) -> int:
return x + 1
@env.task
async def step2(y: int) -> int:
return y * 2
@env.task
async def step3(z: int) -> str:
return f"Result: {z}"
@env.task
async def main(x: int) -> str:
a = await step1(x)
b = await step2(a)
c = await step3(b)
return c
```
> **π Note**
>
> You can only `await` async tasks. If you try to `await` a sync task, it will fail. If your subtasks are sync, call them directly without `await` (they will execute synchronously/sequentially).
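The note above mirrors plain Python semantics: only awaitables can be `await`ed. A minimal stdlib sketch, using ordinary functions and coroutines as stand-ins for sync and async tasks (not real Flyte tasks):

```python
import asyncio

def sync_step(x: int) -> int:
    # Stand-in for a sync task: call it directly, it runs immediately.
    return x + 1

async def async_step(x: int) -> int:
    # Stand-in for an async task: a coroutine that must be awaited.
    return x * 2

async def main() -> tuple:
    a = sync_step(1)          # runs immediately -> 2
    b = await async_step(a)   # awaited coroutine -> 4
    try:
        await sync_step(b)    # awaiting a plain return value raises TypeError
    except TypeError as exc:
        return b, type(exc).__name__

result, err = asyncio.run(main())
print(result, err)  # 4 TypeError
```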
## TaskEnvironment configuration
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env", # Required: unique name
image=flyte.Image.from_debian_base(...), # Or string, or "auto"
resources=flyte.Resources(
cpu="2",
memory="4Gi",
gpu="A100:1",
disk="10Gi",
shm="auto",
),
env_vars={"LOG_LEVEL": "INFO"},
secrets=[
flyte.Secret(key="api-key", as_env_var="API_KEY"),
],
cache="auto", # "auto", "override", "disable", or Cache object
reusable=flyte.ReusePolicy(replicas=5, idle_ttl=60),
queue="gpu-queue",
interruptible=True,
)
# Task decorator can override some settings
@env.task(
short_name="my_task", # Display name
cache="disable", # Override cache
retries=3, # Retry count
timeout=3600, # Seconds or timedelta
report=True, # Generate HTML report
)
def my_task(x: int) -> int:
return x
```
## Parameter mapping: @task to TaskEnvironment + @env.task
| Flyte 1 `@task` parameter | Flyte 2 location | Notes |
|--------------------|-------------|-------|
| `container_image` | `TaskEnvironment(image=...)` | Env-level only |
| `requests` | `TaskEnvironment(resources=...)` | Env-level only |
| `limits` | `TaskEnvironment(resources=...)` | Combined with requests |
| `environment` | `TaskEnvironment(env_vars=...)` | Env-level only |
| `secret_requests` | `TaskEnvironment(secrets=...)` | Env-level only |
| `cache` | Both | Can override at task level |
| `cache_version` | `flyte.Cache(version_override=...)` | In Cache object |
| `retries` | `@env.task(retries=...)` | Task-level only |
| `timeout` | `@env.task(timeout=...)` | Task-level only |
| `interruptible` | Both | Can override at task level |
| `pod_template` | Both | Can override at task level |
| `deprecated` | N/A | Not in Flyte 2 |
| `docs` | `@env.task(docs=...)` | Task-level only |
For full details, see [Configure tasks](../../user-guide/task-configuration/_index).
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/secrets-resources-caching ===
# Secrets, resources, and caching
## Secrets
### Declaring and accessing secrets
### Flyte 1
```python
from flytekit import task, Secret, current_context
@task(secret_requests=[
Secret(group="mygroup", key="mykey"),
Secret(group="db", key="password", mount_requirement=Secret.MountType.ENV_VAR),
])
def my_task() -> str:
ctx = current_context()
secret_value = ctx.secrets.get(key="mykey", group="mygroup")
db_password = ctx.secrets.get(key="password", group="db")
    return "Got secrets"
```
### Flyte 2
```python
import flyte
import os
env = flyte.TaskEnvironment(
name="my_env",
secrets=[
flyte.Secret(key="mykey", as_env_var="MY_SECRET"),
flyte.Secret(key="db-password", as_env_var="DB_PASSWORD"),
],
)
@env.task
def my_task() -> str:
secret_value = os.environ["MY_SECRET"]
db_password = os.environ["DB_PASSWORD"]
    return "Got secrets"
```
### Secret configuration options
```python
# Flyte 2 Secret options
flyte.Secret(
key="secret-name", # Required: secret key in store
group="optional-group", # Optional: organizational group
as_env_var="CUSTOM_ENV_VAR_NAME", # Mount as this env var name
# OR
mount="/etc/flyte/secrets", # Mount as file (fixed path)
)
# Examples
secrets=[
# Simple: key becomes uppercase env var (MY_API_KEY)
flyte.Secret(key="my-api-key"),
# Custom env var name
flyte.Secret(key="openai-key", as_env_var="OPENAI_API_KEY"),
# With group (env var: AWS_ACCESS_KEY)
flyte.Secret(key="access-key", group="aws"),
# As file
flyte.Secret(key="ssl-cert", mount="/etc/flyte/secrets"),
]
```
### Secret name convention changes
| Flyte 1 pattern | Flyte 2 pattern |
|------------|------------|
| `ctx.secrets.get(key="mykey", group="mygroup")` | `os.environ["MYGROUP_MYKEY"]` (auto-named) |
| `ctx.secrets.get(key="mykey")` | `os.environ["MY_SECRET"]` (with `as_env_var`) |
### Creating secrets via CLI
```bash
# Create secret
flyte create secret MY_SECRET_KEY my_secret_value
# From file
flyte create secret MY_SECRET_KEY --from-file /path/to/secret
# Scoped to project/domain
flyte create secret --project my-project --domain development MY_SECRET_KEY value
# List secrets
flyte get secret
# Delete secret
flyte delete secret MY_SECRET_KEY
```
For full details on secrets, see [Secrets](../../user-guide/task-configuration/secrets).
## Resources
### Basic resource configuration
### Flyte 1
```python
from flytekit import task, Resources
# Separate requests and limits
@task(
requests=Resources(cpu="1", mem="2Gi"),
limits=Resources(cpu="2", mem="4Gi"),
)
def my_task(): ...
# Unified resources (tuple for request/limit)
@task(resources=Resources(cpu=("1", "2"), mem="2Gi"))
def my_task(): ...
```
### Flyte 2
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env",
resources=flyte.Resources(
cpu="2", # Request and limit same
memory="4Gi", # Note: "memory" not "mem"
gpu="A100:1", # GPU type and count
disk="10Gi",
shm="auto", # Shared memory
),
)
```
### GPU configuration
### Flyte 1
```python
from flytekit import task, Resources
from flytekit.extras.accelerators import A100
@task(
requests=Resources(gpu="1"),
accelerator=A100,
)
def gpu_task(): ...
```
### Flyte 2
```python
import flyte
env = flyte.TaskEnvironment(
name="gpu_env",
resources=flyte.Resources(
cpu="4",
memory="32Gi",
gpu="A100:2", # Type:count format
# Or: gpu="A100 80G:1"
# Or: gpu=2 # Count only, no type
),
)
# GPU with partition (MIG)
env = flyte.TaskEnvironment(
name="mig_env",
resources=flyte.Resources(
gpu=flyte.GPU("A100", count=1, partition="1g.5gb"),
),
)
```
### Supported GPU types (Flyte 2)
- A10, A10G, A100, A100 80G
- B200, H100, H200
- L4, L40s
- T4, V100
- RTX PRO 6000, GB10
### Resource parameter mapping
| Flyte 1 | Flyte 2 | Notes |
|----|----| ------|
| `cpu="1"` | `cpu="1"` | Same |
| `mem="2Gi"` | `memory="2Gi"` | Renamed |
| `gpu="1"` | `gpu="A100:1"` | Type:count format |
| `ephemeral_storage="10Gi"` | `disk="10Gi"` | Renamed |
| N/A | `shm="auto"` | New: shared memory |
For full details on resources, see [Resources](../../user-guide/task-configuration/resources).
## Caching
### Basic caching
### Flyte 1
```python
from flytekit import task, Cache
@task(cache=True, cache_version="1.0")
def cached_task(x: int) -> int:
return x * 2
# With Cache object
@task(cache=Cache(
version="1.0",
serialize=True,
ignored_inputs=("debug",),
))
def advanced_cached_task(x: int, debug: bool = False) -> int:
return x * 2
```
### Flyte 2
```python
import flyte
env = flyte.TaskEnvironment(
name="my_env",
cache="auto", # Enable caching at env level
)
@env.task
def cached_task(x: int) -> int:
return x * 2
# Override at task level
@env.task(cache="disable")
def uncached_task(x: int) -> int:
return x * 2
# Advanced caching
@env.task(cache=flyte.Cache(
behavior="auto", # "auto", "override", "disable"
version_override="v1.0", # Explicit version
serialize=True, # Force serial execution
ignored_inputs=("debug",), # Exclude from hash
salt="my-salt", # Additional hash salt
))
def advanced_cached_task(x: int, debug: bool = False) -> int:
return x * 2
```
### Cache behavior options (Flyte 2)
| Behavior | Description |
|----------|-------------|
| `"auto"` | Cache results and reuse if available |
| `"override"` | Always execute and overwrite cache |
| `"disable"` | No caching (default for TaskEnvironment) |
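Conceptually, a cached result is keyed by the task identity, its version, and a hash of its inputs, with `ignored_inputs` excluded and `salt` mixed in. The following toy model illustrates that keying scheme; it is not Flyte's actual implementation:

```python
import hashlib
import json

def cache_key(task: str, version: str, inputs: dict,
              ignored: tuple = (), salt: str = "") -> str:
    # Drop ignored inputs so they cannot affect the key, then hash a
    # canonical JSON encoding of everything that remains.
    kept = {k: v for k, v in sorted(inputs.items()) if k not in ignored}
    payload = json.dumps([task, version, kept, salt], sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()

a = cache_key("t", "v1.0", {"x": 1, "debug": True}, ignored=("debug",))
b = cache_key("t", "v1.0", {"x": 1, "debug": False}, ignored=("debug",))
c = cache_key("t", "v1.0", {"x": 2, "debug": True}, ignored=("debug",))
print(a == b, a == c)  # True False: ignored inputs don't change the key
```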
For full details on caching, see [Caching](../../user-guide/task-configuration/caching).
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/parallelism-and-async ===
# Parallelism and async
## Basic map_task migration
### Flyte 1
```python
from flytekit import task, workflow, map_task
@task
def process_item(x: int) -> int:
return x * 2
@workflow
def my_workflow(items: list[int]) -> list[int]:
return map_task(process_item)(x=items)
```
### Flyte 2
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
def process_item(x: int) -> int:
return x * 2
@env.task
def main(items: list[int]) -> list[int]:
return list(flyte.map(process_item, items))
```
## map_task with concurrency
### Flyte 1
```python
@workflow
def my_workflow(items: list[int]) -> list[int]:
return map_task(process_item, concurrency=5)(x=items)
```
### Flyte 2
```python
@env.task
def main(items: list[int]) -> list[int]:
return list(flyte.map(process_item, items, concurrency=5))
```
## Async parallel execution with asyncio.gather
This is the recommended approach for parallel execution in Flyte 2.
```python
import asyncio
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def process_item(item: int) -> str:
return f"processed_{item}"
@env.task
async def main(items: list[int]) -> list[str]:
tasks = [process_item(item) for item in items]
results = await asyncio.gather(*tasks)
return list(results)
```
## Concurrency control with semaphore
```python
import asyncio
@env.task
async def process_item(item: int) -> str:
await asyncio.sleep(1)
return f"processed_{item}"
@env.task
async def main_with_concurrency_limit(
items: list[int],
max_concurrent: int = 5
) -> list[str]:
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_limit(item: int) -> str:
async with semaphore:
return await process_item(item)
tasks = [process_with_limit(item) for item in items]
results = await asyncio.gather(*tasks)
return list(results)
```
## Error handling with asyncio.gather
```python
@env.task
async def main_with_error_handling(
items: list[int],
max_concurrent: int = 5
) -> list[str]:
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_limit(item: int) -> str:
async with semaphore:
return await process_item(item)
tasks = [process_with_limit(item) for item in items]
results = await asyncio.gather(*tasks, return_exceptions=True)
processed = []
for i, result in enumerate(results):
if isinstance(result, Exception):
print(f"Item {items[i]} failed: {result}")
processed.append(f"Failed: {items[i]}")
else:
processed.append(result)
return processed
```
## flyte.map vs asyncio.gather comparison
| Feature | flyte.map (sync) | asyncio.gather (async) |
|---------|------------------|------------------------|
| Syntax | `list(flyte.map(fn, items))` | `await asyncio.gather(*tasks)` |
| Concurrency limit | Built-in `concurrency=N` | Use `asyncio.Semaphore` |
| Streaming/as-completed | No control | Yes, via `asyncio.as_completed()` |
| Error handling | `return_exceptions=True` | `return_exceptions=True`, then check result type |
| Flexibility | Less flexible | More flexible |
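The streaming row in the table refers to `asyncio.as_completed`, which yields results in completion order rather than submission order. A self-contained stdlib sketch (no Flyte required; the delays are contrived so later items finish first):

```python
import asyncio

async def work(i: int) -> int:
    # Contrived delays: item 4 finishes first, item 0 last.
    await asyncio.sleep(0.05 * (5 - i))
    return i * 2

async def main() -> list[int]:
    results = []
    # as_completed yields awaitables as each one finishes, so results
    # arrive in completion order, not submission order.
    for fut in asyncio.as_completed([work(i) for i in range(5)]):
        results.append(await fut)
    return results

out = asyncio.run(main())
print(out)  # [8, 6, 4, 2, 0]
```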
## Recommended pattern selection
Use **flyte.map** when:
- You are forced to use synchronous Python
- You want minimal code changes from Flyte 1 `map_task`
Use **asyncio.gather** when (recommended):
- You want maximum flexibility and control
- You need streaming results (`asyncio.as_completed`)
- You need fine-grained concurrency control (semaphores)
- You're writing new Flyte 2 code
## Sync and async task patterns
Keep task types consistent within a call chain for clarity and predictability.
### Sync tasks calling sync tasks
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
def step1(x: int) -> int:
return x + 1
@env.task
def step2(y: int) -> int:
return y * 2
@env.task
def main(x: int) -> int:
a = step1(x) # Runs, returns result
b = step2(a) # Runs after step1 completes
return b
```
### Async tasks calling async tasks
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def step1(x: int) -> int:
return x + 1
@env.task
async def step2(y: int) -> int:
return y * 2
@env.task
async def main(x: int) -> int:
a = await step1(x) # Runs, waits for result
b = await step2(a) # Runs after step1 completes
return b
```
### Sequential execution with await
When you `await` async tasks one after another, they run sequentially, just like Flyte 1 workflows:
### Flyte 1
```python
@workflow
def my_workflow(x: int) -> str:
a = step1(x=x) # Runs first
b = step2(y=a) # Runs second
c = step3(z=b) # Runs third
return c
```
### Flyte 2
```python
@env.task
async def main(x: int) -> str:
a = await step1(x) # Runs first
b = await step2(a) # Runs second
c = await step3(b) # Runs third
return c
```
> **π Note**
>
> `await` means "wait for this to finish before continuing." Sequential `await` calls behave the same as sequential task calls in Flyte 1 workflows.
For full details on async patterns, see [Asynchronous model](../../user-guide/flyte-2/async).
For full details on parallel fanout, see [Fanout](../../user-guide/task-programming/fanout).
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/triggers-and-dynamic ===
# Triggers and dynamic workflows
## LaunchPlan to Trigger migration
### Flyte 1
```python
from flytekit import workflow, LaunchPlan, CronSchedule, FixedRate
from datetime import timedelta
@workflow
def my_workflow(x: int) -> int:
return process(x)
# Cron schedule
cron_lp = LaunchPlan.get_or_create(
workflow=my_workflow,
name="hourly_run",
default_inputs={"x": 10},
schedule=CronSchedule(schedule="0 * * * *"),
)
# Fixed rate
rate_lp = LaunchPlan.get_or_create(
workflow=my_workflow,
name="frequent_run",
default_inputs={"x": 5},
schedule=FixedRate(duration=timedelta(minutes=30)),
)
```
### Flyte 2
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
# Hourly trigger (convenience method)
@env.task(triggers=flyte.Trigger.hourly())
def hourly_task(x: int = 10) -> int:
return process(x)
# Custom cron trigger
cron_trigger = flyte.Trigger(
name="custom_cron",
automation=flyte.Cron("0 * * * *"),
inputs={"x": 10},
auto_activate=True,
)
@env.task(triggers=cron_trigger)
def scheduled_task(x: int) -> int:
return process(x)
# Fixed rate trigger
rate_trigger = flyte.Trigger(
name="frequent",
automation=flyte.FixedRate(timedelta(minutes=30)),
inputs={"x": 5},
auto_activate=True,
)
@env.task(triggers=rate_trigger)
def frequent_task(x: int) -> int:
return process(x)
```
## Trigger options
```python
# Convenience methods
flyte.Trigger.hourly() # Every hour
flyte.Trigger.hourly("my_time") # Custom time parameter name
flyte.Trigger.minutely() # Every minute
# Custom Trigger
flyte.Trigger(
name="my_trigger", # Required: trigger name
automation=flyte.Cron(...), # Cron or FixedRate
inputs={"x": 10}, # Default inputs
auto_activate=True, # Activate on deploy
)
# Cron options
flyte.Cron(
schedule="0 9 * * 1-5", # 9 AM weekdays
timezone="America/New_York", # Optional timezone
)
# FixedRate options
flyte.FixedRate(timedelta(hours=1)) # Every hour
```
## Deploying triggers
```bash
# Deploy environment (triggers deploy with it)
flyte deploy my_module.py my_env
# Triggers with auto_activate=True activate automatically
# Otherwise, activate manually via UI or API
```
For full details on triggers, see [Triggers](../../user-guide/task-configuration/triggers).
## Dynamic workflows
In Flyte 1, `@dynamic` was needed for tasks that generate variable numbers of subtask calls at runtime. In Flyte 2, all tasks can have dynamic behavior natively.
### @dynamic to regular tasks
### Flyte 1
```python
from flytekit import dynamic, task, workflow
@task
def get_tiles(n: int) -> list[int]:
return list(range(n))
@task
def process_tile(tile: int) -> int:
return tile * 2
@dynamic
def process_all_tiles(tiles: list[int]) -> list[int]:
results = []
for tile in tiles:
results.append(process_tile(tile=tile))
return results
@workflow
def main_workflow(n: int) -> list[int]:
tiles = get_tiles(n=n)
return process_all_tiles(tiles=tiles)
```
### Flyte 2 Sync
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
def process_tile(tile: int) -> int:
return tile * 2
@env.task
def process_all_tiles(tiles: list[int]) -> list[int]:
results = []
for tile in tiles:
results.append(process_tile(tile))
return results
@env.task
def main(n: int) -> list[int]:
tiles = list(range(n))
return process_all_tiles(tiles)
```
### Flyte 2 Async
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def process_tile(tile: int) -> int:
return tile * 2
@env.task
async def process_all_tiles(tiles: list[int]) -> list[int]:
results = []
for tile in tiles:
results.append(await process_tile(tile))
return results
@env.task
async def main(n: int) -> list[int]:
tiles = list(range(n))
return await process_all_tiles(tiles)
```
## Conditional execution
### Flyte 1
```python
from flytekit import conditional
@workflow
def conditional_wf(x: int) -> int:
return (
conditional("test")
.if_(x > 0)
.then(positive_task(x=x))
.else_()
.then(negative_task(x=x))
)
```
### Flyte 2
```python
@env.task
def main(x: int) -> int:
if x > 0:
return positive_task(x)
else:
return negative_task(x)
```
## Subworkflows to nested tasks
### Flyte 1
```python
@workflow
def sub_workflow(x: int) -> int:
    a = step1(x=x)
    b = step2(y=a)
return b
@workflow
def main_workflow(item: int) -> int:
result = sub_workflow(x=item)
return result
```
### Flyte 2
```python
@env.task
def sub_task(x: int) -> int:
a = step1(x)
b = step2(a)
return b
@env.task
def main(item: int) -> int:
result = sub_task(item)
return result
```
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/migration/examples-and-gotchas ===
# Examples and common gotchas
## Complete migration examples
### Example 1: Simple ML pipeline
### Flyte 1
```python
from flytekit import task, workflow, ImageSpec, Resources, current_context
from flytekit.types.file import FlyteFile
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
import joblib
import os
image = ImageSpec(
name="ml-image",
packages=["pandas", "scikit-learn", "joblib", "pyarrow"],
builder="union",
)
@task(
container_image=image,
requests=Resources(cpu="2", mem="4Gi"),
cache=True,
cache_version="1.1",
)
def load_data() -> pd.DataFrame:
CSV_URL = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv"
return pd.read_csv(CSV_URL)
@task(container_image=image)
def train_model(data: pd.DataFrame) -> FlyteFile:
model = RandomForestClassifier()
X = data.drop("species", axis=1)
y = data["species"]
model.fit(X, y)
model_path = os.path.join(current_context().working_directory, "model.joblib")
joblib.dump(model, model_path)
return FlyteFile(path=model_path)
@task(container_image=image)
def evaluate(model_file: FlyteFile, data: pd.DataFrame) -> float:
model = joblib.load(model_file.download())
X = data.drop("species", axis=1)
y = data["species"]
return float(model.score(X, y))
@workflow
def ml_pipeline() -> float:
data = load_data()
model = train_model(data=data)
score = evaluate(model_file=model, data=data)
return score
```
### Flyte 2
```python
import os
import joblib
import pandas as pd
import flyte
from flyte import TaskEnvironment, Resources, Image
from flyte.io import File
from sklearn.ensemble import RandomForestClassifier
# 1. Define the Image using the fluent builder API
image = (
Image.from_debian_base(
name="ml-image",
python_version=(3, 11),
)
.with_pip_packages("pandas", "scikit-learn", "joblib", "pyarrow")
)
# 2. Define the TaskEnvironment
env = TaskEnvironment(
name="ml_env",
image=image,
resources=Resources(cpu="2", memory="4Gi"),
cache="auto",
)
@env.task
async def load_data() -> pd.DataFrame:
CSV_URL = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv"
return pd.read_csv(CSV_URL)
@env.task
async def train_model(data: pd.DataFrame) -> File:
model = RandomForestClassifier()
X = data.drop("species", axis=1)
y = data["species"]
model.fit(X, y)
root_dir = os.getcwd()
model_path = os.path.join(root_dir, "model.joblib")
joblib.dump(model, model_path)
return await File.from_local(model_path)
@env.task
async def evaluate(model_file: File, data: pd.DataFrame) -> float:
local_path = await model_file.download()
model = joblib.load(local_path)
X = data.drop("species", axis=1)
y = data["species"]
return float(model.score(X, y))
# 3. The workflow is now just an orchestrating task
@env.task
async def ml_pipeline() -> float:
data = await load_data()
model = await train_model(data)
score = await evaluate(model, data)
return score
```
### Example 2: Parallel processing with map_task
### Flyte 1
```python
from flytekit import task, workflow, map_task, dynamic
from functools import partial
@task(cache=True, cache_version="1.0")
def get_chunks(n: int) -> list[int]:
return list(range(n))
@task(cache=True, cache_version="1.0")
def process_chunk(chunk_id: int, multiplier: int) -> int:
return chunk_id * multiplier
@workflow
def parallel_pipeline(n: int, multiplier: int) -> list[int]:
    chunk_ids = get_chunks(n=n)
results = map_task(
partial(process_chunk, multiplier=multiplier),
concurrency=10,
)(chunk_id=chunk_ids)
return results
```
### Flyte 2 Sync
```python
from functools import partial
import flyte
env = flyte.TaskEnvironment(name="parallel_env", cache="auto")
@env.task
def process_chunk(chunk_id: int, multiplier: int) -> int:
return chunk_id * multiplier
@env.task
def main(n: int, multiplier: int) -> list[int]:
chunk_ids = list(range(n))
bound_task = partial(process_chunk, multiplier=multiplier)
results = list(flyte.map(bound_task, chunk_ids, concurrency=10))
return results
```
### Flyte 2 Async
```python
import asyncio
import flyte
env = flyte.TaskEnvironment(name="parallel_env", cache="auto")
@env.task
async def process_chunk(chunk_id: int, multiplier: int) -> int:
return chunk_id * multiplier
@env.task
async def main(n: int, multiplier: int) -> list[int]:
chunk_ids = list(range(n))
sem = asyncio.Semaphore(10)
async def throttled_task(cid):
async with sem:
return await process_chunk(cid, multiplier)
tasks = [throttled_task(cid) for cid in chunk_ids]
results = await asyncio.gather(*tasks)
return list(results)
```
## Common gotchas
### 1. current_context() is replaced
```python
# Flyte 1
ctx = flytekit.current_context()
secret = ctx.secrets.get(key="mykey", group="mygroup")
working_dir = ctx.working_directory
execution_id = ctx.execution_id
# Flyte 2 - Secrets via environment variables
secret = os.environ["MYGROUP_MYKEY"]
# Flyte 2 - Context via flyte.ctx()
ctx = flyte.ctx()
```
### 2. Workflow >> ordering notation is gone
```python
# Flyte 1: Using >> to indicate ordering
@workflow
def my_workflow():
t1_result = task1()
t2_result = task2()
t1_result >> t2_result
return t2_result
# Flyte 2: Sequential calls are naturally ordered (sync)
@env.task
def main():
t1_result = task1() # Runs first
t2_result = task2() # Runs second
return t2_result
# Flyte 2: For async, use await to sequence
@env.task
async def main():
t1_result = await task1() # Runs first
t2_result = await task2() # Runs second
return t2_result
```
### 3. flyte.map returns a generator
```python
# Flyte 1: map_task returns list directly
results = map_task(my_task)(items=my_list)
# Flyte 2: flyte.map returns generator - must convert to list
results = list(flyte.map(my_task, my_list)) # Add list()!
# Flyte 2 async: Use asyncio.gather for async parallel execution
tasks = [my_task(item) for item in my_list]
results = await asyncio.gather(*tasks)
```
### 4. Image configuration location
```python
# Flyte 1: Image per task
@task(container_image=my_image)
def task1(): ...
@task(container_image=my_image) # Repeated
def task2(): ...
# Flyte 2: Image at TaskEnvironment level (DRY)
env = flyte.TaskEnvironment(name="my_env", image=my_image)
@env.task
def task1(): ... # Uses env's image
@env.task
def task2(): ... # Uses env's image
```
### 5. Resource configuration
```python
# Flyte 1: Resources per task
@task(requests=Resources(cpu="1"), limits=Resources(cpu="2"))
def my_task(): ...
# Flyte 2: Resources at TaskEnvironment level
env = flyte.TaskEnvironment(
name="my_env",
resources=flyte.Resources(cpu="1"), # No separate limits
)
```
### 6. Cache version
```python
# Flyte 1: Explicit cache_version required
@task(cache=True, cache_version="1.0")
def my_task(): ...
# Flyte 2: Auto-versioning or explicit
@env.task(cache="auto") # Auto-versioned
def my_task(): ...
@env.task(cache=flyte.Cache(behavior="auto", version_override="1.0"))
def my_task_explicit(): ...
```
### 7. Entrypoint task naming
```python
# Flyte 1: Workflow is the entrypoint
@workflow
def my_workflow(): ...
# Flyte 2: Use a main() task or any task name
@env.task
def main(): ... # Common convention
# Run with: flyte run my_module.py main
```
### 8. Memory parameter name
```python
# Flyte 1
Resources(mem="2Gi")
# Flyte 2
flyte.Resources(memory="2Gi") # Note: "memory" not "mem"
```
### 9. Type annotations
```python
# Flyte 1: Strict about type annotations
@task
def my_task(x: int) -> dict: # Would fail, need dict[str, int]
return {"a": x}
# Flyte 2: More lenient
@env.task
def my_task(x: int) -> dict: # Works, v2 will pickle untyped I/O
return {"a": x}
```
## Quick reference cheat sheet
```python
# FLYTE 2 MINIMAL TEMPLATE
import flyte
import asyncio
# 1. Define image
image = (
flyte.Image.from_debian_base(python_version=(3, 11))
.with_pip_packages("pandas", "numpy")
)
# 2. Create TaskEnvironment
env = flyte.TaskEnvironment(
name="my_env",
image=image,
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
# 3. Define tasks
@env.task
async def process(x: int) -> int:
return x * 2
# 4. Define main entrypoint
@env.task
async def main(items: list[int]) -> list[int]:
tasks = [process(x) for x in items]
results = await asyncio.gather(*tasks)
return list(results)
# 5. Run
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main, items=[1, 2, 3, 4, 5])
run.wait()
```
```bash
# CLI COMMANDS
flyte run my_module.py main --items '[1,2,3,4,5]'
flyte run --local my_module.py main --items '[1,2,3,4,5]'
flyte deploy my_module.py my_env
```
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-cli ===
# Flyte CLI
This is the command line interface for Flyte.
| Object | Commands |
| ------ | -------- |
| `action` | `flyte abort action`, `flyte get action` |
| `run` | `flyte abort run`, `flyte get run` |
| `config` | `flyte create config`, `flyte get config` |
| `project` | `flyte create project`, `flyte get project`, `flyte update project` |
| `secret` | `flyte create secret`, `flyte delete secret`, `flyte get secret` |
| `trigger` | `flyte create trigger`, `flyte delete trigger`, `flyte get trigger`, `flyte update trigger` |
| `app` | `flyte delete app`, `flyte get app`, `flyte update app` |
| `docs` | `flyte gen docs` |
| `io` | `flyte get io` |
| `logs` | `flyte get logs` |
| `task` | `flyte get task` |
| `hf-model` | `flyte prefetch hf-model` |
| `deployed-task` | `flyte run deployed-task` |
| `tui` | `flyte start tui` |
| Action | On |
| ------ | -- |
| `abort` | **Flyte CLI > flyte > flyte abort > flyte abort action**, **Flyte CLI > flyte > flyte abort > flyte abort run** |
| **Flyte CLI > flyte > flyte build** | - |
| `create` | **Flyte CLI > flyte > flyte create > flyte create config**, **Flyte CLI > flyte > flyte create > flyte create project**, **Flyte CLI > flyte > flyte create > flyte create secret**, **Flyte CLI > flyte > flyte create > flyte create trigger** |
| `delete` | **Flyte CLI > flyte > flyte delete > flyte delete app**, **Flyte CLI > flyte > flyte delete > flyte delete secret**, **Flyte CLI > flyte > flyte delete > flyte delete trigger** |
| **Flyte CLI > flyte > flyte deploy** | - |
| `gen` | **Flyte CLI > flyte > flyte gen > flyte gen docs** |
| `get` | **Flyte CLI > flyte > flyte get > flyte get action**, **Flyte CLI > flyte > flyte get > flyte get app**, **Flyte CLI > flyte > flyte get > flyte get config**, **Flyte CLI > flyte > flyte get > flyte get io**, **Flyte CLI > flyte > flyte get > flyte get logs**, **Flyte CLI > flyte > flyte get > flyte get project**, **Flyte CLI > flyte > flyte get > flyte get run**, **Flyte CLI > flyte > flyte get > flyte get secret**, **Flyte CLI > flyte > flyte get > flyte get task**, **Flyte CLI > flyte > flyte get > flyte get trigger** |
| `prefetch` | **Flyte CLI > flyte > flyte prefetch > flyte prefetch hf-model** |
| `run` | **Flyte CLI > flyte > flyte run > flyte run deployed-task** |
| `serve` | **Flyte CLI > flyte > flyte serve** |
| `start` | **Flyte CLI > flyte > flyte start > flyte start tui** |
| `update` | **Flyte CLI > flyte > flyte update > flyte update app**, **Flyte CLI > flyte > flyte update > flyte update project**, **Flyte CLI > flyte > flyte update > flyte update trigger** |
| `whoami` | **Flyte CLI > flyte > flyte whoami** |
## flyte
**`flyte [OPTIONS] COMMAND [ARGS]...`**
The Flyte CLI is the command line interface for working with the Flyte SDK and backend.
It follows a simple verb/noun structure,
where the top-level commands are verbs that describe the action to be taken,
and the subcommands are nouns that describe the object of the action.
The root command can be used to configure the CLI for persistent settings,
such as the endpoint, organization, and verbosity level.
Set endpoint and organization:
```bash
$ flyte --endpoint <endpoint> --org <org> get project
```
Increase the verbosity level (useful for debugging; shows more logs and exception traces):
```bash
$ flyte -vvv get logs
```
Override the default config file:
```bash
$ flyte --config /path/to/config.yaml run ...
```
* [Documentation](https://www.union.ai/docs/flyte/user-guide/)
* [GitHub](https://github.com/flyteorg/flyte): Please leave a star if you like Flyte!
* [Slack](https://slack.flyte.org): Join the community and ask questions.
* [Issues](https://github.com/flyteorg/flyte/issues)
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--version` | `boolean` | `False` | Show the version and exit. |
| `--endpoint` | `text` | `Sentinel.UNSET` | The endpoint to connect to. This will override any configuration file and simply use `pkce` to connect. |
| `--insecure` | `boolean` | | Use an insecure connection to the endpoint. If not specified, the CLI will use TLS. |
| `--auth-type` | `choice` | | Authentication type to use for the Flyte backend. Defaults to 'pkce'. |
| `-v`
`--verbose` | `integer` | `0` | Show verbose messages and exception traces. Repeating multiple times increases the verbosity (e.g., -vvv). |
| `--org` | `text` | `Sentinel.UNSET` | The organization to which the command applies. |
| `-c`
`--config` | `file` | `Sentinel.UNSET` | Path to the configuration file to use. If not specified, the default configuration file is used. |
| `--output-format`
`-of` | `choice` | `table` | Output format for commands that support it. Defaults to 'table'. |
| `--log-format` | `choice` | `console` | Formatting for logs, defaults to 'console' which is meant to be human readable. 'json' is meant for machine parsing. |
| `--reset-root-logger` | `boolean` | `False` | If set, the root logger will be reset to use Flyte logging style |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte abort
**`flyte abort COMMAND [ARGS]...`**
Abort an ongoing process.
#### flyte abort action
**`flyte abort action [OPTIONS] RUN_NAME ACTION_NAME`**
Abort an action associated with a run.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--reason` | `text` | `Manually aborted from the CLI` | The reason to abort the run. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte abort run
**`flyte abort run [OPTIONS] RUN_NAME`**
Abort a run.
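For instance, aborting a run with a custom reason can be sketched as follows (the run name `my_run` is hypothetical):

```bash
$ flyte abort run my_run --reason "Superseded by a newer run"
```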
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--reason` | `text` | `Manually aborted from the CLI` | The reason to abort the run. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte build
**`flyte build [OPTIONS] COMMAND [ARGS]...`**
Build the environments defined in a python file or directory. This will build the images associated with the
environments.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--noop` | `boolean` | `Sentinel.UNSET` | Dummy parameter, placeholder for future use. Does not affect the build process. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte create
**`flyte create COMMAND [ARGS]...`**
Create resources in a Flyte deployment.
#### flyte create config
**`flyte create config [OPTIONS]`**
Create a configuration file for the Flyte CLI.
If the `--output` option is not specified, it will create a file named `config.yaml` in the current directory.
If the file already exists, it will raise an error unless the `--force` option is used.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--endpoint` | `text` | `Sentinel.UNSET` | Endpoint of the Flyte backend. |
| `--insecure` | `boolean` | `False` | Use an insecure connection to the Flyte backend. |
| `--org` | `text` | `Sentinel.UNSET` | Organization to use. This will override the organization in the configuration file. |
| `-o`
`--output` | `path` | `.flyte/config.yaml` | Path to the output directory where the configuration will be saved. Defaults to current directory. |
| `--force` | `boolean` | `False` | Force overwrite of the configuration file if it already exists. |
| `--image-builder`
`--builder` | `choice` | `local` | Image builder to use for building images. Defaults to 'local'. |
| `--auth-type` | `choice` | | Authentication type to use for the Flyte backend. Defaults to 'pkce'. |
| `--local-persistence` | `boolean` | `False` | Enable SQLite persistence for local run metadata, allowing past runs to be browsed via 'flyte start tui'. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create project
**`flyte create project [OPTIONS]`**
Create a new project.
Example usage:
```bash
flyte create project --id my_project_id --name "My Project"
flyte create project --id my_project_id --name "My Project" --description "My project" -l team=ml -l env=prod
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--id` | `text` | `Sentinel.UNSET` | Unique identifier for the project (immutable). |
| `--name` | `text` | `Sentinel.UNSET` | Display name for the project. |
| `--description` | `text` | `` | Description for the project. |
| `--label`
`-l` | `text` | `Sentinel.UNSET` | Labels as key=value pairs. Can be specified multiple times. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create secret
**`flyte create secret [OPTIONS] NAME`**
Create a new secret. The name of the secret is required. For example:
CODE4
If you don't provide a `--value` flag, you will be prompted to enter the
secret value in the terminal.
CODE5
If `--from-file` is specified, the value will be read from the file instead of being provided directly:
CODE6
The `--type` option can be used to create specific types of secrets.
Either `regular` or `image_pull` can be specified.
Secrets intended to access container images should be specified as `image_pull`.
Other secrets should be specified as `regular`.
If no type is specified, `regular` is assumed.
For image pull secrets, you have several options:
1. Interactive mode (prompts for registry, username, password):
CODE7
2. With explicit credentials:
CODE8
3. Lastly, you can create a secret from your existing Docker installation (i.e., you've run `docker login` in
the past) and you just want to pull from those credentials. Since you may have logged in to multiple registries,
you can specify which registries to include. If no registries are specified, all registries are added.
CODE9
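Putting the variants above together, the commands can be sketched as follows (secret names, file paths, and registries here are hypothetical placeholders):

```bash
# Provide the value directly (omit --value to be prompted interactively)
$ flyte create secret my_secret --value "my-secret-value"

# Read the value from a file instead
$ flyte create secret my_secret --from-file /path/to/secret.bin

# Image pull secret with explicit registry credentials (prompts for password)
$ flyte create secret my_pull_secret --type image_pull --registry ghcr.io --username my_user

# Image pull secret from an existing Docker login, limited to one registry
$ flyte create secret my_pull_secret --type image_pull --from-docker-config --registries ghcr.io
```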
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--value` | `text` | `Sentinel.UNSET` | Secret value. Mutually exclusive with from_file, from_docker_config, registry. |
| `--from-file` | `path` | `Sentinel.UNSET` | Path to the file with the binary secret. Mutually exclusive with value, from_docker_config, registry. |
| `--type` | `choice` | `regular` | Type of the secret. |
| `--from-docker-config` | `boolean` | `False` | Create image pull secret from Docker config file (only for --type image_pull). Mutually exclusive with value, from_file, registry, username, password. |
| `--docker-config-path` | `path` | `Sentinel.UNSET` | Path to Docker config file (defaults to ~/.docker/config.json or $DOCKER_CONFIG). Requires from_docker_config. |
| `--registries` | `text` | `Sentinel.UNSET` | Comma-separated list of registries to include (only with --from-docker-config). |
| `--registry` | `text` | `Sentinel.UNSET` | Registry hostname (e.g., ghcr.io, docker.io) for explicit credentials (only for --type image_pull). Mutually exclusive with value, from_file, from_docker_config. |
| `--username` | `text` | `Sentinel.UNSET` | Username for the registry (only with --registry). |
| `--password` | `text` | `Sentinel.UNSET` | Password for the registry (only with --registry). If not provided, will prompt. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create trigger
**`flyte create trigger [OPTIONS] TASK_NAME NAME`**
Create a new trigger for a task. The task name and trigger name are required.
Example:
CODE10
This will create a trigger that runs every day at midnight.
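The daily-at-midnight trigger described above can be sketched with a standard cron expression (task and trigger names are hypothetical):

```bash
$ flyte create trigger my_task my_daily_trigger --schedule "0 0 * * *" --description "Runs every day at midnight"
```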
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--schedule` | `text` | `Sentinel.UNSET` | Cron schedule for the trigger. Defaults to every minute. |
| `--description` | `text` | `` | Description of the trigger. |
| `--auto-activate` | `boolean` | `True` | Whether the trigger should be automatically activated. Defaults to True. |
| `--trigger-time-var` | `text` | `trigger_time` | Variable name for the trigger time in the task inputs. Defaults to 'trigger_time'. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte delete
**`flyte delete COMMAND [ARGS]...`**
Remove resources from a Flyte deployment.
#### flyte delete app
**`flyte delete app [OPTIONS] NAME`**
Delete apps from a Flyte deployment.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte delete secret
**`flyte delete secret [OPTIONS] NAME`**
Delete a secret. The name of the secret is required.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte delete trigger
**`flyte delete trigger [OPTIONS] NAME TASK_NAME`**
Delete a trigger. The name of the trigger is required.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte deploy
**`flyte deploy [OPTIONS] COMMAND [ARGS]...`**
Deploy one or more environments from a python file.
This command will create or update environments in the Flyte system, registering
all tasks and their dependencies.
Example usage:
CODE11
Arguments to the deploy command are provided right after the `deploy` command and before the file name.
To deploy all environments in a file, use the `--all` flag:
CODE12
To recursively deploy all environments in a directory and its subdirectories, use the `--recursive` flag:
CODE13
You can combine `--all` and `--recursive` to deploy everything:
CODE14
You can provide image mappings with `--image` flag. This allows you to specify
the image URI for the task environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
CODE15
If the image name is not provided, it is regarded as a default image and will
be used when no image is specified in TaskEnvironment:
CODE16
You can specify multiple image arguments:
CODE17
To deploy a specific version, use the `--version` flag:
CODE18
To preview what would be deployed without actually deploying, use the `--dry-run` flag:
CODE19
You can specify the `--config` flag to point to a specific Flyte cluster:
CODE20
You can override the default configured project and domain:
CODE21
If loading some files fails during recursive deployment, you can use the `--ignore-load-errors` flag
to continue deploying the environments that loaded successfully:
CODE22
Other arguments to the deploy command are listed below.
To see the environments available in a file, use `--help` after the file name:
CODE23
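The deployment variants described above can be sketched as follows (file, environment, and image names are hypothetical; flags come before the file name, and the environment name follows it):

```bash
# Deploy a specific environment defined in a file
$ flyte deploy my_file.py my_env

# Deploy all environments in a file
$ flyte deploy --all my_file.py

# Map a named image reference to a concrete URI
$ flyte deploy --image my_image=ghcr.io/my_org/my_image:latest my_file.py my_env

# Pin a version, or preview without deploying
$ flyte deploy --version v1 --dry-run my_file.py my_env

# Point at a specific cluster and override project/domain
$ flyte --config /path/to/config.yaml deploy -p my_project -d development my_file.py my_env

# See the environments available in a file
$ flyte deploy my_file.py --help
```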
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--version` | `text` | `Sentinel.UNSET` | Version of the environment to deploy |
| `--dry-run`
`--dryrun` | `boolean` | `False` | Dry run. Do not actually call the backend service. |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when running the task |
| `--root-dir` | `text` | `Sentinel.UNSET` | Override the root source directory, helpful when working with monorepos. |
| `--recursive`
`-r` | `boolean` | `False` | Recursively deploy all environments in the current directory |
| `--all` | `boolean` | `False` | Deploy all environments in the current directory, ignoring the file name |
| `--ignore-load-errors`
`-i` | `boolean` | `False` | Ignore errors when loading environments, especially when using --recursive or --all. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--image` | `text` | `Sentinel.UNSET` | Image to be used in the run. Format: imagename=imageuri. Can be specified multiple times. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte gen
**`flyte gen COMMAND [ARGS]...`**
Generate documentation.
#### flyte gen docs
**`flyte gen docs [OPTIONS]`**
Generate documentation.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--type` | `text` | `Sentinel.UNSET` | Type of documentation (valid: markdown) |
| `--plugin-variants` | `text` | | Hugo variant names for plugin commands (e.g., 'byoc selfmanaged'). When set, plugin command sections and index entries are wrapped in {{< variant >}} shortcodes. Core commands appear unconditionally. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte get
**`flyte get COMMAND [ARGS]...`**
Retrieve resources from a Flyte deployment.
You can get information about projects, runs, tasks, actions, secrets, logs and input/output values.
Each command supports optional parameters to filter or specify the resource you want to retrieve.
Using a `get` subcommand without any arguments will retrieve a list of available resources to get.
For example:
* `get project` (without specifying a project), will list all projects.
* `get project my_project` will return the details of the project named `my_project`.
In some cases, a partially specified command will act as a filter and return available further parameters.
For example:
* `get action my_run` will return all actions for the run named `my_run`.
* `get action my_run my_action` will return the details of the action named `my_action` for the run `my_run`.
#### flyte get action
**`flyte get action [OPTIONS] RUN_NAME [ACTION_NAME]`**
Get all actions for a run or details for a specific action.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--in-phase` | `choice` | `Sentinel.UNSET` | Filter actions by their phase. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get app
**`flyte get app [OPTIONS] [NAME]`**
Get a list of all apps, or details of a specific app by name.
Apps are long-running services deployed on the Flyte platform.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of apps to fetch when listing. |
| `--only-mine` | `boolean` | `False` | Show only apps created by the current user (you). |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get config
**`flyte get config`**
Show the automatically detected configuration used to connect to the remote backend.
The configuration will include the endpoint, organization, and other settings that are used by the CLI.
#### flyte get io
**`flyte get io [OPTIONS] RUN_NAME [ACTION_NAME]`**
Get the inputs and outputs of a run or action.
If only the run name is provided, it will show the inputs and outputs of the root action of that run.
If an action name is provided, it will show the inputs and outputs for that action.
If `--inputs-only` or `--outputs-only` is specified, it will only show the inputs or outputs respectively.
Examples:
CODE24
CODE25
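The behavior described above can be sketched as follows (run and action names are hypothetical):

```bash
# Inputs and outputs of the root action of a run
$ flyte get io my_run

# Only the inputs of a specific action
$ flyte get io my_run my_action --inputs-only
```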
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--inputs-only`
`-i` | `boolean` | `False` | Show only inputs |
| `--outputs-only`
`-o` | `boolean` | `False` | Show only outputs |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get logs
**`flyte get logs [OPTIONS] RUN_NAME [ACTION_NAME]`**
Stream logs for the provided run or action.
If only the run is provided, only the logs for the parent action will be streamed:
CODE26
If you want to see the logs for a specific action, you can provide the action name as well:
CODE27
By default, logs are shown in raw format and scroll the terminal.
To instead tail only the last `--lines` lines in an auto-scrolling box, use the `--pretty` flag:
CODE28
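The log-streaming variants described above can be sketched as follows (run and action names are hypothetical):

```bash
# Stream logs for the parent action of a run
$ flyte get logs my_run

# Stream logs for a specific action
$ flyte get logs my_run my_action

# Tail only the last 50 lines in an auto-scrolling box
$ flyte get logs my_run --pretty --lines 50
```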
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--lines`
`-l` | `integer` | `30` | Number of lines to show, only useful for --pretty |
| `--show-ts` | `boolean` | `False` | Show timestamps |
| `--pretty` | `boolean` | `False` | Show logs in an auto-scrolling box, where number of lines is limited to `--lines` |
| `--attempt`
`-a` | `integer` | | Attempt number to show logs for, defaults to the latest attempt. |
| `--filter-system` | `boolean` | `False` | Filter all system logs from the output. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get project
**`flyte get project [OPTIONS] [NAME]`**
Get a list of all projects, or details of a specific project by name.
By default, only active (unarchived) projects are shown. Use `--archived` to
show archived projects instead.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--archived` | `boolean` | `False` | Show archived projects instead of active ones. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get run
**`flyte get run [OPTIONS] [NAME]`**
Get a list of all runs, or details of a specific run by name.
The run details include information about the run and its status, but only the root action is shown.
To see the actions for a run, use `get action <run_name>`.
You can filter runs by task name and optionally task version:
CODE29
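The task-name filtering described above can be sketched as follows (task name and version are hypothetical):

```bash
# List runs of a given task
$ flyte get run --task-name my_task

# Narrow to a specific task version
$ flyte get run --task-name my_task --task-version v1
```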
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of runs to fetch when listing. |
| `--in-phase` | `choice` | `Sentinel.UNSET` | Filter runs by their status. |
| `--only-mine` | `boolean` | `False` | Show only runs created by the current user (you). |
| `--task-name` | `text` | | Filter runs by task name. |
| `--task-version` | `text` | | Filter runs by task version. |
| `--created-after` | `datetime` | | Show runs created at or after this datetime (UTC). Accepts ISO dates, 'now', 'today', or 'now - 1 day'. |
| `--created-before` | `datetime` | | Show runs created before this datetime (UTC). |
| `--updated-after` | `datetime` | | Show runs updated at or after this datetime (UTC). Accepts ISO dates, 'now', 'today', or 'now - 1 day'. |
| `--updated-before` | `datetime` | | Show runs updated before this datetime (UTC). |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get secret
**`flyte get secret [OPTIONS] [NAME]`**
Get a list of all secrets, or details of a specific secret by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get task
**`flyte get task [OPTIONS] [NAME] [VERSION]`**
Retrieve a list of all tasks, or details of a specific task by name and version.
Currently, both `name` and `version` are required to get a specific task.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of tasks to fetch. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get trigger
**`flyte get trigger [OPTIONS] [TASK_NAME] [NAME]`**
Get a list of all triggers, or details of a specific trigger by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of triggers to fetch. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte prefetch
**`flyte prefetch COMMAND [ARGS]...`**
Prefetch artifacts from remote registries.
These commands help you download and prefetch artifacts like HuggingFace models
to your Flyte storage for faster access during task execution.
#### flyte prefetch hf-model
**`flyte prefetch hf-model [OPTIONS] REPO`**
Prefetch a HuggingFace model to Flyte storage.
Downloads a model from the HuggingFace Hub and prefetches it to your configured
Flyte storage backend. This is useful for:
- Pre-fetching large models before running inference tasks
- Sharding models for tensor-parallel inference
- Avoiding repeated downloads during development
**Basic Usage:**
CODE30
**With Sharding:**
Create a shard config file (shard_config.yaml):
CODE31
Then run:
CODE32
**Wait for Completion:**
CODE33
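The usage described above can be sketched as follows (the repo name `my-org/my-model` and the shard config contents are hypothetical; per the `--shard-config` description, the file needs `engine` and `args` keys):

```bash
# Basic usage: prefetch a model repo from the HuggingFace Hub
$ flyte prefetch hf-model my-org/my-model

# With sharding: shard_config.yaml contains, e.g.:
#   engine: vllm
#   args: {}   # engine-specific sharding arguments
$ flyte prefetch hf-model my-org/my-model --shard-config shard_config.yaml --gpu A100:2

# Block until the model is fully stored
$ flyte prefetch hf-model my-org/my-model --wait
```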
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--raw-data-path` | `text` | | Object store path to store the model. If not provided, the model will be stored using the default path generated by the Flyte storage layer. |
| `--artifact-name` | `text` | | Artifact name to use for the stored model. Must only contain alphanumeric characters, underscores, and hyphens. If not provided, the repo name will be used (replacing '.' with '-'). |
| `--architecture` | `text` | `Sentinel.UNSET` | Model architecture, as given in HuggingFace config.json. |
| `--task` | `text` | `auto` | Model task, e.g., 'generate', 'classify', 'embed', 'score', etc. Refer to vLLM docs. 'auto' will try to discover this automatically. |
| `--modality` | `text` | `('text',)` | Modalities supported by the model, e.g., 'text', 'image', 'audio', 'video'. Can be specified multiple times. |
| `--format` | `text` | `Sentinel.UNSET` | Model serialization format, e.g., safetensors, onnx, torchscript, joblib, etc. |
| `--model-type` | `text` | `Sentinel.UNSET` | Model type, e.g., 'transformer', 'xgboost', 'custom', etc. For HuggingFace models, this is auto-determined from config.json['model_type']. |
| `--short-description` | `text` | `Sentinel.UNSET` | Short description of the model. |
| `--force` | `integer` | `0` | Force store of the model. Increment value (--force=1, --force=2, ...) to force a new store. |
| `--wait` | `boolean` | `False` | Wait for the model to be stored before returning. |
| `--hf-token-key` | `text` | `HF_TOKEN` | Name of the Flyte secret containing your HuggingFace token. Note: This is not the HuggingFace token itself, but the name of the secret in the Flyte secret store. |
| `--cpu` | `text` | `2` | CPU request for the prefetch task (e.g., '2', '4', '2,4' for 2-4 CPUs). |
| `--mem` | `text` | `8Gi` | Memory request for the prefetch task (e.g., '16Gi', '64Gi', '16Gi,64Gi' for 16-64GB). |
| `--gpu` | `choice` | | The gpu to use for downloading and (optionally) sharding the model. Format: '{type}:{quantity}' (e.g., 'A100:8', 'L4:1'). |
| `--disk` | `text` | `50Gi` | Disk storage request for the prefetch task (e.g., '100Gi', '500Gi'). |
| `--shm` | `text` | | Shared memory request for the prefetch task (e.g., '100Gi', 'auto'). |
| `--shard-config` | `path` | `Sentinel.UNSET` | Path to a YAML file containing sharding configuration. The file should have 'engine' (currently only 'vllm') and 'args' keys. |
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte run
**`flyte run [OPTIONS] COMMAND [ARGS]...`**
Run a task from a python file or deployed task.
Example usage:
CODE34
Arguments to the run command are provided right after the `run` command and before the file name.
Arguments for the task itself are provided after the task name.
To run a task locally, use the `--local` flag. This will run the task in the local environment instead of the remote
Flyte environment:
CODE35
You can provide image mappings with `--image` flag. This allows you to specify
the image URI for the task environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
CODE36
If the image name is not provided, it is regarded as a default image and will
be used when no image is specified in TaskEnvironment:
CODE37
You can specify multiple image arguments:
CODE38
To run tasks that you've already deployed to Flyte, use the deployed-task command:
CODE39
To run a specific version of a deployed task, use the `env.task:version` syntax:
CODE40
You can specify the `--config` flag to point to a specific Flyte cluster:
CODE41
You can override the default configured project and domain:
CODE42
You can discover what deployed tasks are available by running:
CODE43
To run an arbitrary Python script on a remote cluster (without defining a task), use `python-script`:
CODE44
You can also install extra packages and wait for completion:
CODE45
Other arguments to the run command are listed below.
Arguments for the task itself are provided after the task name and can be retrieved using `--help`. For example:
CODE46
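The run variants described above can be sketched as follows (file, task, image, and argument names are hypothetical; task arguments follow the task name):

```bash
# Run a task from a file on the remote backend
$ flyte run my_file.py my_task --my_input 42

# Run the task locally instead
$ flyte run --local my_file.py my_task --my_input 42

# Map a named image reference to a concrete URI
$ flyte run --image my_image=ghcr.io/my_org/my_image:latest my_file.py my_task

# Run an already-deployed task, optionally pinning a version
$ flyte run deployed-task my_env.my_task
$ flyte run deployed-task my_env.my_task:v1

# See the tasks and arguments available in a file
$ flyte run my_file.py my_task --help
```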
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--local` | `boolean` | `False` | Run the task locally |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when running the task |
| `--root-dir` | `text` | `Sentinel.UNSET` | Override the root source directory, helpful when working with monorepos. |
| `--raw-data-path` | `text` | `Sentinel.UNSET` | Override the output prefix used to store offloaded data types. e.g. s3://bucket/ |
| `--service-account` | `text` | `Sentinel.UNSET` | Kubernetes service account. If not provided, the configured default will be used |
| `--name` | `text` | `Sentinel.UNSET` | Name of the run. If not provided, a random name will be generated. |
| `--follow`
`-f` | `boolean` | `False` | Wait and watch logs for the parent action. If not provided, the CLI will exit after successfully launching a remote execution with a link to the UI. |
| `--tui` | `boolean` | `False` | Show interactive TUI for local execution (requires flyte[tui]). |
| `--image` | `text` | `Sentinel.UNSET` | Image to be used in the run. Format: imagename=imageuri. Can be specified multiple times. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--run-project` | `text` | | Run the remote task in this project, only applicable when using `deployed-task` subcommand. |
| `--run-domain` | `text` | | Run the remote task in this domain, only applicable when using `deployed-task` subcommand. |
| `--debug` | `boolean` | `False` | Run the task as a VSCode debug task. Starts a code-server in the container so you can connect via the UI to interactively debug/run the task. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte run deployed-task
**`flyte run deployed-task [OPTIONS] COMMAND [ARGS]...`**
Run a remote task from the Flyte backend.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`
`--project` | `text` | | Project to which this command applies. |
| `-d`
`--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte serve
**`flyte serve [OPTIONS] COMMAND [ARGS]...`**
Serve an app from a Python file using flyte.serve().
This command allows you to serve apps defined with `flyte.app.AppEnvironment`
in your Python files. The serve command will deploy the app to the Flyte backend
and start it, making it accessible via a URL.
Example usage:
CODE47
**Local serving:** Use the `--local` flag to serve the app on localhost without
deploying to the Flyte backend. This is useful for local development and testing:
CODE48
Arguments to the serve command are provided right after the `serve` command and before the file name.
To follow the logs of the served app, use the `--follow` flag:
CODE49
Note: Log streaming is not yet fully implemented and will be added in a future release.
You can provide image mappings with `--image` flag. This allows you to specify
the image URI for the app environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
CODE50
If the image name is omitted, the URI is treated as the default image and is used
when no image is specified in the `AppEnvironment`:
CODE51
You can specify multiple image arguments:
CODE52
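Conceptually, the `--image` name-to-URI resolution described above can be sketched in plain Python. This is an illustrative sketch, not the CLI's actual implementation; the `"default"` key and the function names are hypothetical:

```python
def parse_image_flags(values: list[str]) -> dict[str, str]:
    """Parse repeated --image flag values into a name -> URI mapping.

    A value of the form "name=uri" maps that name; a bare "uri" (no "=")
    is treated as the default image (stored here under a hypothetical
    "default" key).
    """
    mapping: dict[str, str] = {}
    for value in values:
        if "=" in value:
            name, uri = value.split("=", 1)
            mapping[name] = uri
        else:
            mapping["default"] = value  # used when no image name is given
    return mapping


def resolve(ref_name: str, mapping: dict[str, str]) -> str:
    """Resolve an Image.from_ref_name("name") reference to a URI."""
    if ref_name in mapping:
        return mapping[ref_name]
    return mapping["default"]
```

For example, passing `--image trainer=ghcr.io/acme/train:v1 --image ghcr.io/acme/base:latest` would resolve `trainer` to the first URI and any other reference to the bare default URI.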
You can specify the `--config` flag to point to a specific Flyte cluster:
CODE53
You can override the default configured project and domain:
CODE54
Other arguments to the serve command are listed below.
Note: This pattern is primarily useful for serving apps defined in tasks.
Serving deployed apps is not currently supported through this CLI command.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when serving the app |
| `--root-dir` | `text` | `Sentinel.UNSET` | Override the root source directory, helpful when working with monorepos. |
| `--service-account` | `text` | `Sentinel.UNSET` | Kubernetes service account. If not provided, the configured default will be used |
| `--name` | `text` | `Sentinel.UNSET` | Name of the app deployment. If not provided, the app environment name will be used. |
| `--follow`, `-f` | `boolean` | `False` | Wait and watch logs for the app. If not provided, the CLI will exit after successfully deploying the app with a link to the UI. |
| `--image` | `text` | `Sentinel.UNSET` | Image to be used in the serve. Format: imagename=imageuri. Can be specified multiple times. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--env-var`, `-e` | `text` | `Sentinel.UNSET` | Environment variable to set in the app. Format: KEY=VALUE. Can be specified multiple times. Example: --env-var LOG_LEVEL=DEBUG --env-var DATABASE_URL=postgresql://... |
| `--local` | `boolean` | `False` | Serve the app locally on localhost instead of deploying to the Flyte backend. The app will be served on the port defined in the AppEnvironment. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte start
**`flyte start COMMAND [ARGS]...`**
Start various Flyte services.
#### flyte start tui
**`flyte start tui`**
Launch the TUI in explore mode to browse past local runs. To use the TUI, install it with `pip install flyte[tui]`.
The TUI lets you explore all of your local runs if you have persistence enabled.
Persistence can be enabled in two ways:
1. By setting it in the config to record every local run:
CODE55
2. By passing `local_persistence=True` to `flyte.init()`.
This records all local `flyte.run` runs executed while that `flyte.init` is active.
### flyte update
**`flyte update COMMAND [ARGS]...`**
Update various flyte entities.
#### flyte update app
**`flyte update app [OPTIONS] NAME`**
Update an app by starting or stopping it.
Example usage:
CODE56
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--activate` / `--deactivate` | `boolean` | | Activate or deactivate app. |
| `--wait` | `boolean` | `False` | Wait for the app to reach the desired state. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte update project
**`flyte update project [OPTIONS] ID`**
Update a project's name, description, labels, or archive state.
Example usage:
CODE57
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--name` | `text` | | Update the project display name. |
| `--description` | `text` | | Update the project description. |
| `--label`, `-l` | `text` | `Sentinel.UNSET` | Set labels as key=value pairs. Can be specified multiple times. Replaces all existing labels. |
| `--archive` / `--unarchive` | `boolean` | | Archive or unarchive the project. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte update trigger
**`flyte update trigger [OPTIONS] NAME TASK_NAME`**
Update a trigger.
Example usage:
CODE58
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--activate` / `--deactivate` | `boolean` | `Sentinel.UNSET` | Activate or deactivate the trigger. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte whoami
**`flyte whoami`**
Display the current user information.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk ===
# Flyte SDK
These are the docs for Flyte SDK version 2.0
Flyte is the core Python SDK for the Union and Flyte platforms.
## Subpages
- **Flyte SDK > Classes**
- **Flyte SDK > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/classes ===
# Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte > Cache** |Cache configuration for a task. |
| **Flyte SDK > Packages > flyte > Cron** |Cron-based automation schedule for use with `Trigger`. |
| **Flyte SDK > Packages > flyte > Device** |Represents a device type, its quantity and partition if applicable. |
| **Flyte SDK > Packages > flyte > Environment** | |
| **Flyte SDK > Packages > flyte > FixedRate** |Fixed-rate (interval-based) automation schedule for use with `Trigger`. |
| **Flyte SDK > Packages > flyte > Image** |Container image specification built using a fluent, two-step pattern. |
| **Flyte SDK > Packages > flyte > ImageBuild** |Result of an image build operation. |
| **Flyte SDK > Packages > flyte > PodTemplate** |Custom PodTemplate specification for a Task. |
| **Flyte SDK > Packages > flyte > Resources** |Resources such as CPU, Memory, and GPU that can be allocated to a task. |
| **Flyte SDK > Packages > flyte > RetryStrategy** |Retry strategy for the task or task environment. |
| **Flyte SDK > Packages > flyte > ReusePolicy** |Configure a task environment for container reuse across multiple task invocations. |
| **Flyte SDK > Packages > flyte > Secret** |Secrets are used to inject sensitive information into tasks or image build context. |
| **Flyte SDK > Packages > flyte > TaskEnvironment** |Define an execution environment for a set of tasks. |
| **Flyte SDK > Packages > flyte > Timeout** |Timeout class to define a timeout for a task. |
| **Flyte SDK > Packages > flyte > Trigger** |Specification for a scheduled trigger that can be associated with any Flyte task. |
| **Flyte SDK > Packages > flyte.app > AppEndpoint** |Embed an upstream app's endpoint as an app parameter. |
| **Flyte SDK > Packages > flyte.app > AppEnvironment** |Configure a long-running app environment for APIs, dashboards, or model servers. |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment** | |
| **Flyte SDK > Packages > flyte.app > Domain** |Subdomain to use for the domain. |
| **Flyte SDK > Packages > flyte.app > Link** |Custom links to add to the app. |
| **Flyte SDK > Packages > flyte.app > Parameter** |Parameter for application. |
| **Flyte SDK > Packages > flyte.app > Port** | |
| **Flyte SDK > Packages > flyte.app > RunOutput** |Use a run's output for app parameters. |
| **Flyte SDK > Packages > flyte.app > Scaling** |Controls replica count and autoscaling behavior for app environments. |
| **Flyte SDK > Packages > flyte.app > Timeouts** |Timeout configuration for the application. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIPassthroughAuthMiddleware** |FastAPI middleware that automatically sets Flyte auth metadata from request headers. |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment** |A pre-built FastAPI app environment for common Flyte webhook operations. |
| **Flyte SDK > Packages > flyte.config > Config** |This is the parent configuration object and holds all the underlying configuration object types. |
| **Flyte SDK > Packages > flyte.connectors > AsyncConnector** |This is the base class for all async connectors, and it defines the interface that all connectors must implement. |
| **Flyte SDK > Packages > flyte.connectors > AsyncConnectorExecutorMixin** |This mixin class is used to run the connector task locally, and it's only used for local execution. |
| **Flyte SDK > Packages > flyte.connectors > ConnectorRegistry** |This is the registry for all connectors. |
| **Flyte SDK > Packages > flyte.connectors > ConnectorService** | |
| **Flyte SDK > Packages > flyte.connectors > Resource** |This is the output resource of the job. |
| **Flyte SDK > Packages > flyte.connectors > ResourceMeta** |This is the metadata for the job. |
| **Flyte SDK > Packages > flyte.errors > ActionAbortedError** |This error is raised when an action was aborted, externally. |
| **Flyte SDK > Packages > flyte.errors > ActionNotFoundError** |This error is raised when the user tries to access an action that does not exist. |
| **Flyte SDK > Packages > flyte.errors > BaseRuntimeError** |Base class for all Union runtime errors. |
| **Flyte SDK > Packages > flyte.errors > CodeBundleError** |This error is raised when the code bundle cannot be created, for example when no files are found to bundle. |
| **Flyte SDK > Packages > flyte.errors > CustomError** |This error is raised when the user raises a custom error. |
| **Flyte SDK > Packages > flyte.errors > DeploymentError** |This error is raised when the deployment of a task fails, or some preconditions for deployment are not met. |
| **Flyte SDK > Packages > flyte.errors > ImageBuildError** |This error is raised when the image build fails. |
| **Flyte SDK > Packages > flyte.errors > ImagePullBackOffError** |This error is raised when the image cannot be pulled. |
| **Flyte SDK > Packages > flyte.errors > InitializationError** |This error is raised when the Union system is accessed without being initialized. |
| **Flyte SDK > Packages > flyte.errors > InlineIOMaxBytesBreached** |This error is raised when the inline IO max bytes limit is breached. |
| **Flyte SDK > Packages > flyte.errors > InvalidImageNameError** |This error is raised when the image name is invalid. |
| **Flyte SDK > Packages > flyte.errors > InvalidPackageError** |Raised when an invalid system package is detected during image build. |
| **Flyte SDK > Packages > flyte.errors > LogsNotYetAvailableError** |This error is raised when the logs are not yet available for a task. |
| **Flyte SDK > Packages > flyte.errors > ModuleLoadError** |This error is raised when the module cannot be loaded, for example because it does not exist. |
| **Flyte SDK > Packages > flyte.errors > NonRecoverableError** |Raised when an error is encountered that is not recoverable. |
| **Flyte SDK > Packages > flyte.errors > NotInTaskContextError** |This error is raised when the user tries to access the task context outside of a task. |
| **Flyte SDK > Packages > flyte.errors > OOMError** |This error is raised when the underlying task execution fails because of an out-of-memory error. |
| **Flyte SDK > Packages > flyte.errors > OnlyAsyncIOSupportedError** |This error is raised when the user tries to use sync IO in an async task. |
| **Flyte SDK > Packages > flyte.errors > ParameterMaterializationError** |This error is raised when a `Parameter` with delayed materialization is used in an App. |
| **Flyte SDK > Packages > flyte.errors > PrimaryContainerNotFoundError** |This error is raised when the primary container is not found. |
| **Flyte SDK > Packages > flyte.errors > RemoteTaskNotFoundError** |This error is raised when the user tries to access a task that does not exist. |
| **Flyte SDK > Packages > flyte.errors > RemoteTaskUsageError** |This error is raised when the user tries to access a task that does not exist. |
| **Flyte SDK > Packages > flyte.errors > RestrictedTypeError** |This error is raised when the user uses a restricted type; for example, a `Tuple` is currently not supported. |
| **Flyte SDK > Packages > flyte.errors > RetriesExhaustedError** |This error is raised when the underlying task execution fails after all retries have been exhausted. |
| **Flyte SDK > Packages > flyte.errors > RuntimeDataValidationError** |This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > RuntimeSystemError** |This error is raised when the underlying task execution fails because of a system error. |
| **Flyte SDK > Packages > flyte.errors > RuntimeUnknownError** |This error is raised when the underlying task execution fails because of an unknown error. |
| **Flyte SDK > Packages > flyte.errors > RuntimeUserError** |This error is raised when the underlying task execution fails because of an error in the user's code. |
| **Flyte SDK > Packages > flyte.errors > SlowDownError** |This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > TaskInterruptedError** |This error is raised when the underlying task execution is interrupted. |
| **Flyte SDK > Packages > flyte.errors > TaskTimeoutError** |This error is raised when the underlying task execution runs for longer than the specified timeout. |
| **Flyte SDK > Packages > flyte.errors > TraceDoesNotAllowNestedTasksError** |This error is raised when the user tries to use a task from within a trace. |
| **Flyte SDK > Packages > flyte.errors > UnionRpcError** |This error is raised when communication with the Union server fails. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate** |A task template that wraps an asynchronous function. |
| **Flyte SDK > Packages > flyte.extend > ImageBuildEngine** |ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate** |Task template is a template for a task that can be executed. |
| **Flyte SDK > Packages > flyte.extras > BatchStats** |Monitoring statistics exposed by `DynamicBatcher`. |
| [`flyte.extras.ContainerTask`](../packages/flyte.extras/containertask/page.md) |This is an intermediate class that represents Flyte Tasks that run a container at execution time. |
| [`flyte.extras.DynamicBatcher`](../packages/flyte.extras/dynamicbatcher/page.md) |Batches records from many concurrent producers. |
| [`flyte.extras.Prompt`](../packages/flyte.extras/prompt/page.md) |Simple prompt record with built-in token estimation. |
| [`flyte.extras.TokenBatcher`](../packages/flyte.extras/tokenbatcher/page.md) |Token-aware batcher for LLM inference workloads. |
| [`flyte.git.GitStatus`](../packages/flyte.git/gitstatus/page.md) |A class representing the status of a git repository. |
| [`flyte.io.DataFrame`](../packages/flyte.io/dataframe/page.md) |A Flyte meta DataFrame object that wraps other dataframe types (usually available as plugins, e.g. pandas). |
| [`flyte.io.Dir`](../packages/flyte.io/dir/page.md) |A generic directory class representing a directory with files of a specified format. |
| [`flyte.io.File`](../packages/flyte.io/file/page.md) |A generic file class representing a file with a specified format. |
| [`flyte.io.HashFunction`](../packages/flyte.io/hashfunction/page.md) |A hash method that wraps a user-provided function to compute hashes. |
| [`flyte.io.extend.DataFrameDecoder`](../packages/flyte.io.extend/dataframedecoder/page.md) | |
| [`flyte.io.extend.DataFrameEncoder`](../packages/flyte.io.extend/dataframeencoder/page.md) | |
| [`flyte.io.extend.DataFrameTransformerEngine`](../packages/flyte.io.extend/dataframetransformerengine/page.md) |Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. |
| [`flyte.models.ActionID`](../packages/flyte.models/actionid/page.md) |A class representing the ID of an Action, nested within a Run. |
| [`flyte.models.ActionPhase`](../packages/flyte.models/actionphase/page.md) |Represents the execution phase of a Flyte action (run). |
| [`flyte.models.Checkpoints`](../packages/flyte.models/checkpoints/page.md) |A class representing the checkpoints for a task. |
| [`flyte.models.CodeBundle`](../packages/flyte.models/codebundle/page.md) |A class representing a code bundle for a task. |
| [`flyte.models.GroupData`](../packages/flyte.models/groupdata/page.md) | |
| [`flyte.models.NativeInterface`](../packages/flyte.models/nativeinterface/page.md) |A class representing the native interface for a task. |
| [`flyte.models.PathRewrite`](../packages/flyte.models/pathrewrite/page.md) |Configuration for rewriting paths during input loading. |
| [`flyte.models.RawDataPath`](../packages/flyte.models/rawdatapath/page.md) |A class representing the raw data path for a task. |
| [`flyte.models.SerializationContext`](../packages/flyte.models/serializationcontext/page.md) |This object holds serialization-time contextual information that can be used when serializing the task. |
| [`flyte.models.TaskContext`](../packages/flyte.models/taskcontext/page.md) |A context class to hold the current task executions context. |
| [`flyte.prefetch.HuggingFaceModelInfo`](../packages/flyte.prefetch/huggingfacemodelinfo/page.md) |Information about a HuggingFace model to store. |
| [`flyte.prefetch.ShardConfig`](../packages/flyte.prefetch/shardconfig/page.md) |Configuration for model sharding. |
| [`flyte.prefetch.StoredModelInfo`](../packages/flyte.prefetch/storedmodelinfo/page.md) |Information about a stored model. |
| [`flyte.prefetch.VLLMShardArgs`](../packages/flyte.prefetch/vllmshardargs/page.md) |Arguments for sharding a model using vLLM. |
| [`flyte.remote.Action`](../packages/flyte.remote/action/page.md) |A class representing an action. |
| [`flyte.remote.ActionDetails`](../packages/flyte.remote/actiondetails/page.md) |A class representing an action. |
| [`flyte.remote.ActionInputs`](../packages/flyte.remote/actioninputs/page.md) |A class representing the inputs of an action. |
| [`flyte.remote.ActionOutputs`](../packages/flyte.remote/actionoutputs/page.md) |A class representing the outputs of an action. |
| [`flyte.remote.App`](../packages/flyte.remote/app/page.md) | |
| [`flyte.remote.Project`](../packages/flyte.remote/project/page.md) |A class representing a project in the Union API. |
| [`flyte.remote.Run`](../packages/flyte.remote/run/page.md) |A class representing a run of a task. |
| [`flyte.remote.RunDetails`](../packages/flyte.remote/rundetails/page.md) |A class representing a run of a task. |
| [`flyte.remote.Secret`](../packages/flyte.remote/secret/page.md) | |
| [`flyte.remote.Task`](../packages/flyte.remote/task/page.md) | |
| [`flyte.remote.TaskDetails`](../packages/flyte.remote/taskdetails/page.md) | |
| [`flyte.remote.TimeFilter`](../packages/flyte.remote/timefilter/page.md) |Filter for time-based fields. |
| [`flyte.remote.Trigger`](../packages/flyte.remote/trigger/page.md) |Represents a trigger in the Flyte platform. |
| [`flyte.remote.User`](../packages/flyte.remote/user/page.md) |Represents a user in the Flyte platform. |
| [`flyte.report.Report`](../packages/flyte.report/report/page.md) | |
| [`flyte.sandbox.CodeTaskTemplate`](../packages/flyte.sandbox/codetasktemplate/page.md) |A sandboxed task created from a code string rather than a decorated function. |
| [`flyte.sandbox.ImageConfig`](../packages/flyte.sandbox/imageconfig/page.md) |Configuration for Docker image building at runtime. |
| [`flyte.sandbox.SandboxedConfig`](../packages/flyte.sandbox/sandboxedconfig/page.md) |Configuration for a sandboxed task executed via Monty. |
| [`flyte.sandbox.SandboxedTaskTemplate`](../packages/flyte.sandbox/sandboxedtasktemplate/page.md) |A task template that executes the function body in a Monty sandbox. |
| [`flyte.storage.ABFS`](../packages/flyte.storage/abfs/page.md) |Any Azure Blob Storage specific configuration. |
| [`flyte.storage.GCS`](../packages/flyte.storage/gcs/page.md) |Any GCS specific configuration. |
| [`flyte.storage.S3`](../packages/flyte.storage/s3/page.md) |S3 specific configuration. |
| [`flyte.storage.Storage`](../packages/flyte.storage/storage/page.md) |Data storage configuration that applies across any provider. |
| [`flyte.syncify.Syncify`](../packages/flyte.syncify/syncify/page.md) |A decorator to convert asynchronous functions or methods into synchronous ones. |
| [`flyte.types.FlytePickle`](../packages/flyte.types/flytepickle/page.md) |This type is only used by flytekit internally. |
| [`flyte.types.TypeEngine`](../packages/flyte.types/typeengine/page.md) |Core Extensible TypeEngine of Flytekit. |
| [`flyte.types.TypeTransformer`](../packages/flyte.types/typetransformer/page.md) |Base transformer type that should be implemented for every python native type that can be handled by flytekit. |
| [`flyte.types.TypeTransformerFailedError`](../packages/flyte.types/typetransformerfailederror/page.md) | |
# Protocols
| Protocol | Description |
|-|-|
| [`flyte.AppHandle`](../packages/flyte/apphandle/page.md) |Protocol defining the common interface between local and remote app handles. |
| [`flyte.CachePolicy`](../packages/flyte/cachepolicy/page.md) |Protocol for custom cache version strategies. |
| [`flyte.Link`](../packages/flyte/link/page.md) | |
| [`flyte.extend.ImageBuilder`](../packages/flyte.extend/imagebuilder/page.md) | |
| [`flyte.extend.ImageChecker`](../packages/flyte.extend/imagechecker/page.md) | |
| [`flyte.extras.CostEstimator`](../packages/flyte.extras/costestimator/page.md) |Protocol for records that can estimate their own processing cost. |
| [`flyte.extras.TokenEstimator`](../packages/flyte.extras/tokenestimator/page.md) |Protocol for records that can estimate their own token count. |
| [`flyte.types.Renderable`](../packages/flyte.types/renderable/page.md) | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages ===
# Packages
| Package | Description |
|-|-|
| **Flyte SDK > Packages > flyte** | Flyte SDK for authoring compound AI applications, services and workflows. |
| **Flyte SDK > Packages > flyte.app** | |
| **Flyte SDK > Packages > flyte.app.extras** | |
| **Flyte SDK > Packages > flyte.config** | |
| **Flyte SDK > Packages > flyte.connectors** | |
| **Flyte SDK > Packages > flyte.connectors.utils** | |
| **Flyte SDK > Packages > flyte.durable** | Flyte durable utilities. |
| **Flyte SDK > Packages > flyte.errors** | Exceptions raised by Union. |
| **Flyte SDK > Packages > flyte.extend** | |
| **Flyte SDK > Packages > flyte.extras** | Flyte extras package. |
| **Flyte SDK > Packages > flyte.git** | |
| **Flyte SDK > Packages > flyte.io** | IO data types. |
| **Flyte SDK > Packages > flyte.io.extend** | |
| **Flyte SDK > Packages > flyte.models** | |
| **Flyte SDK > Packages > flyte.prefetch** | Prefetch utilities for Flyte. |
| **Flyte SDK > Packages > flyte.remote** | Remote Entities that are accessible from the Union Server once deployed or created. |
| **Flyte SDK > Packages > flyte.report** | |
| **Flyte SDK > Packages > flyte.sandbox** | Sandbox utilities for running isolated code inside Flyte tasks. |
| **Flyte SDK > Packages > flyte.storage** | |
| **Flyte SDK > Packages > flyte.syncify** | Syncify module. |
| **Flyte SDK > Packages > flyte.types** | Flyte type system. |
## Subpages
- **Flyte SDK > Packages > flyte**
- **Flyte SDK > Packages > flyte.app**
- **Flyte SDK > Packages > flyte.app.extras**
- **Flyte SDK > Packages > flyte.config**
- **Flyte SDK > Packages > flyte.connectors**
- **Flyte SDK > Packages > flyte.connectors.utils**
- **Flyte SDK > Packages > flyte.durable**
- **Flyte SDK > Packages > flyte.errors**
- **Flyte SDK > Packages > flyte.extend**
- **Flyte SDK > Packages > flyte.extras**
- **Flyte SDK > Packages > flyte.git**
- **Flyte SDK > Packages > flyte.io**
- **Flyte SDK > Packages > flyte.io.extend**
- **Flyte SDK > Packages > flyte.models**
- **Flyte SDK > Packages > flyte.prefetch**
- **Flyte SDK > Packages > flyte.remote**
- **Flyte SDK > Packages > flyte.report**
- **Flyte SDK > Packages > flyte.sandbox**
- **Flyte SDK > Packages > flyte.storage**
- **Flyte SDK > Packages > flyte.syncify**
- **Flyte SDK > Packages > flyte.types**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte ===
# flyte
Flyte SDK for authoring compound AI applications, services and workflows.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte > Cache** | Cache configuration for a task. |
| **Flyte SDK > Packages > flyte > Cron** | Cron-based automation schedule for use with `Trigger`. |
| **Flyte SDK > Packages > flyte > Device** | Represents a device type, its quantity and partition if applicable. |
| **Flyte SDK > Packages > flyte > Environment** | |
| **Flyte SDK > Packages > flyte > FixedRate** | Fixed-rate (interval-based) automation schedule for use with `Trigger`. |
| **Flyte SDK > Packages > flyte > Image** | Container image specification built using a fluent, two-step pattern. |
| **Flyte SDK > Packages > flyte > ImageBuild** | Result of an image build operation. |
| **Flyte SDK > Packages > flyte > PodTemplate** | Custom PodTemplate specification for a Task. |
| **Flyte SDK > Packages > flyte > Resources** | Resources such as CPU, Memory, and GPU that can be allocated to a task. |
| **Flyte SDK > Packages > flyte > RetryStrategy** | Retry strategy for the task or task environment. |
| **Flyte SDK > Packages > flyte > ReusePolicy** | Configure a task environment for container reuse across multiple task invocations. |
| **Flyte SDK > Packages > flyte > Secret** | Secrets are used to inject sensitive information into tasks or image build context. |
| **Flyte SDK > Packages > flyte > TaskEnvironment** | Define an execution environment for a set of tasks. |
| **Flyte SDK > Packages > flyte > Timeout** | Timeout class to define a timeout for a task. |
| **Flyte SDK > Packages > flyte > Trigger** | Specification for a scheduled trigger that can be associated with any Flyte task. |
### Protocols
| Protocol | Description |
|-|-|
| **Flyte SDK > Packages > flyte > AppHandle** | Protocol defining the common interface between local and remote app handles. |
| **Flyte SDK > Packages > flyte > CachePolicy** | Protocol for custom cache version strategies. |
| **Flyte SDK > Packages > flyte > Link** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > Methods > AMD_GPU()** | Create an AMD GPU device instance. |
| **Flyte SDK > Packages > flyte > Methods > GPU()** | Create a GPU device instance. |
| **Flyte SDK > Packages > flyte > Methods > HABANA_GAUDI()** | Create a Habana Gaudi device instance. |
| **Flyte SDK > Packages > flyte > Methods > Neuron()** | Create a Neuron device instance. |
| **Flyte SDK > Packages > flyte > Methods > TPU()** | Create a TPU device instance. |
| **Flyte SDK > Packages > flyte > Methods > build()** | Build an image. |
| **Flyte SDK > Packages > flyte > Methods > build_images()** | Build the images for the given environments. |
| **Flyte SDK > Packages > flyte > Methods > ctx()** | Returns the current task context. |
| **Flyte SDK > Packages > flyte > Methods > current_domain()** | Returns the current domain from Runtime environment (on the cluster) or from the initialized configuration. |
| **Flyte SDK > Packages > flyte > Methods > current_project()** | Returns the current project from the Runtime environment (on the cluster) or from the initialized configuration. |
| **Flyte SDK > Packages > flyte > Methods > custom_context()** | Synchronous context manager to set input context for tasks spawned within this block. |
| **Flyte SDK > Packages > flyte > Methods > deploy()** | Deploy the given environment or list of environments. |
| **Flyte SDK > Packages > flyte > Methods > get_custom_context()** | Get the current input context. |
| **Flyte SDK > Packages > flyte > Methods > group()** | Create a new group with the given name. |
| **Flyte SDK > Packages > flyte > Methods > init()** | Initialize the Flyte system with the given configuration. |
| **Flyte SDK > Packages > flyte > Methods > init_from_api_key()** | Initialize the Flyte system using an API key for authentication. |
| **Flyte SDK > Packages > flyte > Methods > init_from_config()** | Initialize the Flyte system using a configuration file or Config object. |
| **Flyte SDK > Packages > flyte > Methods > init_in_cluster()** | |
| **Flyte SDK > Packages > flyte > Methods > init_passthrough()** | Initialize the Flyte system with passthrough authentication. |
| **Flyte SDK > Packages > flyte > Methods > map()** | Map a function over the provided arguments with concurrent execution. |
| **Flyte SDK > Packages > flyte > Methods > run()** | Run a task with the given parameters. |
| **Flyte SDK > Packages > flyte > Methods > run_python_script()** | Package and run a Python script on a remote Flyte cluster. |
| **Flyte SDK > Packages > flyte > Methods > serve()** | Serve a Flyte app using an AppEnvironment. |
| **Flyte SDK > Packages > flyte > trace()** | A decorator that traces function execution with timing information. |
| **Flyte SDK > Packages > flyte > version()** | Returns the version of the Flyte SDK. |
| **Flyte SDK > Packages > flyte > with_runcontext()** | Launch a new run with the given parameters as the context. |
| **Flyte SDK > Packages > flyte > with_servecontext()** | Create a serve context with custom configuration. |
### Variables
| Property | Type | Description |
|-|-|-|
| `TimeoutType` | `UnionType` | |
| `TriggerTime` | `_trigger_time` | |
| `__version__` | `str` | |
| `logger` | `Logger` | |
## Methods
#### AMD_GPU()
```python
def AMD_GPU(
device: typing.Literal['MI100', 'MI210', 'MI250', 'MI250X', 'MI300A', 'MI300X', 'MI325X', 'MI350X', 'MI355X'],
) -> flyte._resources.Device
```
Create an AMD GPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['MI100', 'MI210', 'MI250', 'MI250X', 'MI300A', 'MI300X', 'MI325X', 'MI350X', 'MI355X']` | Device type (e.g., "MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X", "MI355X"). |
**Returns:** Device instance.
#### GPU()
```python
def GPU(
device: typing.Literal['A10', 'A10G', 'A100', 'A100 80G', 'B200', 'H100', 'H200', 'L4', 'L40s', 'T4', 'V100', 'RTX PRO 6000', 'GB10'],
quantity: typing.Literal[1, 2, 3, 4, 5, 6, 7, 8],
partition: typing.Union[typing.Literal['1g.5gb', '2g.10gb', '3g.20gb', '4g.20gb', '7g.40gb'], typing.Literal['1g.10gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.10gb', '1g.20gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.18gb', '1g.35gb', '2g.35gb', '3g.71gb', '4g.71gb', '7g.141gb'], NoneType],
) -> flyte._resources.Device
```
Create a GPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['A10', 'A10G', 'A100', 'A100 80G', 'B200', 'H100', 'H200', 'L4', 'L40s', 'T4', 'V100', 'RTX PRO 6000', 'GB10']` | The type of GPU (e.g., "T4", "A100"). |
| `quantity` | `typing.Literal[1, 2, 3, 4, 5, 6, 7, 8]` | The number of GPUs of this type. |
| `partition` | `typing.Union[typing.Literal['1g.5gb', '2g.10gb', '3g.20gb', '4g.20gb', '7g.40gb'], typing.Literal['1g.10gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.10gb', '1g.20gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.18gb', '1g.35gb', '2g.35gb', '3g.71gb', '4g.71gb', '7g.141gb'], NoneType]` | The partition of the GPU (e.g., "1g.5gb", "2g.10gb" for gpus) or ("1x1", ... for tpus). |
**Returns:** Device instance.
#### HABANA_GAUDI()
```python
def HABANA_GAUDI(
device: typing.Literal['Gaudi1'],
) -> flyte._resources.Device
```
Create a Habana Gaudi device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['Gaudi1']` | Device type (e.g., "Gaudi1"). |
**Returns:** Device instance.
#### Neuron()
```python
def Neuron(
device: typing.Literal['Inf1', 'Inf2', 'Trn1', 'Trn1n', 'Trn2', 'Trn2u'],
) -> flyte._resources.Device
```
Create a Neuron device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['Inf1', 'Inf2', 'Trn1', 'Trn1n', 'Trn2', 'Trn2u']` | Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u"). |
**Returns:** Device instance.
#### TPU()
```python
def TPU(
device: typing.Literal['V5P', 'V6E'],
partition: typing.Union[typing.Literal['2x2x1', '2x2x2', '2x4x4', '4x4x4', '4x4x8', '4x8x8', '8x8x8', '8x8x16', '8x16x16', '16x16x16', '16x16x24'], typing.Literal['1x1', '2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16'], NoneType],
)
```
Create a TPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['V5P', 'V6E']` | Device type (e.g., "V5P", "V6E"). |
| `partition` | `typing.Union[typing.Literal['2x2x1', '2x2x2', '2x4x4', '4x4x4', '4x4x8', '4x8x8', '8x8x8', '8x8x16', '8x16x16', '16x16x16', '16x16x24'], typing.Literal['1x1', '2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16'], NoneType]` | Partition of the TPU (e.g., "1x1", "2x2", ...). |
**Returns:** Device instance.
#### build()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await build.aio()`.
```python
def build(
image: Image,
dry_run: bool,
force: bool,
wait: bool,
) -> ImageBuild
```
Build an image. The existing async context will be used.
Example:
```python
import asyncio
import flyte
image = flyte.Image("example_image")
if __name__ == "__main__":
    result = asyncio.run(flyte.build.aio(image))
    print(result.uri)
```
| Parameter | Type | Description |
|-|-|-|
| `image` | `Image` | The image(s) to build. |
| `dry_run` | `bool` | Tell the builder to not actually build. Different builders will have different behaviors. |
| `force` | `bool` | Skip the existence check and force a rebuild. When using the remote builder, this also sets overwrite_cache=True on the build run. |
| `wait` | `bool` | Wait for the build to finish. If wait is False, the function will return immediately and the build will run in the background. |
**Returns:** An ImageBuild object with the image URI and remote run (if applicable).
#### build_images()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await build_images.aio()`.
```python
def build_images(
envs: Environment,
) -> ImageCache
```
Build the images for the given environments.
| Parameter | Type | Description |
|-|-|-|
| `envs` | `Environment` | Environment to build images for. |
**Returns:** ImageCache containing the built images.
#### ctx()
```python
def ctx()
```
Returns `flyte.models.TaskContext` if called within a task context, else `None`.
Note: Only use this in task code, not at module level.
#### current_domain()
```python
def current_domain()
```
Returns the current domain from Runtime environment (on the cluster) or from the initialized configuration.
This is safe to be used during `deploy`, `run` and within `task` code.
NOTE: This will not work if you deploy a task to a domain and then run it in another domain.
Raises InitializationError if the configuration is not initialized or domain is not set.
**Returns:** The current domain
#### current_project()
```python
def current_project()
```
Returns the current project from the Runtime environment (on the cluster) or from the initialized configuration.
This is safe to be used during `deploy`, `run` and within `task` code.
NOTE: This will not work if you deploy a task to a project and then run it in another project.
Raises InitializationError if the configuration is not initialized or project is not set.
**Returns:** The current project
#### custom_context()
```python
def custom_context(
    **context: str,
)
```
Synchronous context manager to set input context for tasks spawned within this block.
Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
    ctx = flyte.get_custom_context()
    print(ctx)

@env.task
def main():
    # context can be passed via a context manager
    with flyte.custom_context(project="my-project"):
        t1()  # will have {'project': 'my-project'} as context
```
| Parameter | Type | Description |
|-|-|-|
| `context` | `str` | Key-value pairs to set as input context |
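Flyte's internal implementation is not shown in this reference, but the propagation behavior described above can be illustrated with a standalone sketch built on Python's `contextvars` (the names `custom_context` and `get_custom_context` below are local stand-ins, not the `flyte` functions):

```python
import contextvars
from contextlib import contextmanager

# Local stand-in for Flyte's internal context storage (illustration only).
_custom_ctx: contextvars.ContextVar = contextvars.ContextVar("custom_ctx", default={})

@contextmanager
def custom_context(**context: str):
    # Merge new keys over any inherited context so nested blocks compose.
    token = _custom_ctx.set({**_custom_ctx.get(), **context})
    try:
        yield
    finally:
        _custom_ctx.reset(token)

def get_custom_context() -> dict:
    return _custom_ctx.get()

def t1() -> dict:
    return get_custom_context()

with custom_context(project="my-project"):
    assert t1() == {"project": "my-project"}
assert t1() == {}  # context is restored on exit
```

Because `contextvars` values are inherited by work spawned inside the block, the same mechanism naturally extends to nested calls and sub-tasks.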
#### deploy()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await deploy.aio()`.
```python
def deploy(
envs: Environment,
dryrun: bool,
version: str | None,
interactive_mode: bool | None,
copy_style: CopyFiles,
) -> List[Deployment]
```
Deploy the given environment or list of environments.
| Parameter | Type | Description |
|-|-|-|
| `envs` | `Environment` | Environment or list of environments to deploy. |
| `dryrun` | `bool` | dryrun mode, if True, the deployment will not be applied to the control plane. |
| `version` | `str \| None` | version of the deployment, if None, the version will be computed from the code bundle. |
| `interactive_mode` | `bool \| None` | Optional, can be forced to True or False. If not provided, it will be set based on the current environment. For example Jupyter notebooks are considered interactive mode, while scripts are not. This is used to determine how the code bundle is created. |
| `copy_style` | `CopyFiles` | Copy style to use when running the task |
**Returns:** Deployment object containing the deployed environments and tasks.
#### get_custom_context()
```python
def get_custom_context()
```
Get the current input context. This can be used within a task to retrieve
context metadata that was passed to the action.
Context will automatically propagate to sub-actions.
Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
    # context can be retrieved with `get_custom_context`
    ctx = flyte.get_custom_context()
    print(ctx)  # {'project': '...', 'entity': '...'}
```
#### group()
```python
def group(
    name: str,
)
```
Context manager that groups tasks launched within its block under the given name.
Example:
```python
@task
async def my_task():
    ...
    with group("my_group"):
        t1(x, y)  # tasks in this block will be grouped under "my_group"
    ...
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the group |
#### init()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await init.aio()`.
```python
def init(
    org: str | None,
    project: str | None,
    domain: str | None,
    root_dir: Path | None,
    log_level: int | None,
    log_format: LogFormat | None,
    reset_root_logger: bool,
    endpoint: str | None,
    headless: bool,
    insecure: bool,
    insecure_skip_verify: bool,
    ca_cert_file_path: str | None,
    auth_type: AuthType,
    command: List[str] | None,
    proxy_command: List[str] | None,
    api_key: str | None,
    client_id: str | None,
    client_credentials_secret: str | None,
    auth_client_config: ClientConfig | None,
    rpc_retries: int,
    http_proxy_url: str | None,
    storage: Storage | None,
    batch_size: int,
    image_builder: ImageBuildEngine.ImageBuilderType,
    images: typing.Dict[str, str] | None,
    source_config_path: Optional[Path],
    sync_local_sys_paths: bool,
    load_plugin_type_transformers: bool,
    local_persistence: bool,
) -> None
```
Initialize the Flyte system with the given configuration. This method should be called before any other Flyte
remote API methods are called. Thread-safe implementation.
| Parameter | Type | Description |
|-|-|-|
| `org` | `str \| None` | Optional organization override for the client. Should be set by auth instead. |
| `project` | `str \| None` | Optional project name (not used in this implementation) |
| `domain` | `str \| None` | Optional domain name (not used in this implementation) |
| `root_dir` | `Path \| None` | Optional root directory used to resolve file paths (e.g., config files) and to determine which code needs to be copied to the remote location. Defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd. |
| `log_level` | `int \| None` | Optional logging level for the logger, default is set using the default initialization policies |
| `log_format` | `LogFormat \| None` | Optional logging format for the logger, default is "console" |
| `reset_root_logger` | `bool` | By default, we clear out root logger handlers and set up our own. |
| `endpoint` | `str \| None` | Optional API endpoint URL |
| `headless` | `bool` | Optional Whether to run in headless mode |
| `insecure` | `bool` | insecure flag for the client |
| `insecure_skip_verify` | `bool` | Whether to skip SSL certificate verification |
| `ca_cert_file_path` | `str \| None` | [optional] str Root Cert to be loaded and used to verify admin |
| `auth_type` | `AuthType` | The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow) |
| `command` | `List[str] \| None` | This command is executed to return a token using an external process |
| `proxy_command` | `List[str] \| None` | This command is executed to return a token for proxy authorization using an external process |
| `api_key` | `str \| None` | Optional API key for authentication |
| `client_id` | `str \| None` | This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. |
| `client_credentials_secret` | `str \| None` | Used for service auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the password directly from the environment variable. Note that this is less secure! Please only use this if mounting the secret as a file is impossible |
| `auth_client_config` | `ClientConfig \| None` | Optional client configuration for authentication |
| `rpc_retries` | `int` | [optional] int Number of times to retry the platform calls |
| `http_proxy_url` | `str \| None` | [optional] HTTP Proxy to be used for OAuth requests |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration if needed to access (i.e. using Minio) |
| `batch_size` | `int` | Optional batch size for operations that use listings, defaults to 1000, so limit larger than batch_size will be split into multiple requests. |
| `image_builder` | `ImageBuildEngine.ImageBuilderType` | Optional image builder configuration, if not provided, the default image builder will be used. |
| `images` | `typing.Dict[str, str] \| None` | Optional dict of images that can be used by referencing the image name. |
| `source_config_path` | `Optional[Path]` | Optional path to the source configuration file (This is only used for documentation) |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True). |
| `load_plugin_type_transformers` | `bool` | If enabled (default True), load the type transformer plugins registered under the "flyte.plugins.types" entry point group. |
| `local_persistence` | `bool` | Whether to enable SQLite persistence for local run metadata. |
**Returns:** None
#### init_from_api_key()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await init_from_api_key.aio()`.
```python
def init_from_api_key(
    api_key: str | None,
    project: str | None,
    domain: str | None,
    root_dir: Path | None,
    log_level: int | None,
    log_format: LogFormat | None,
    storage: Storage | None,
    batch_size: int,
    image_builder: ImageBuildEngine.ImageBuilderType,
    images: typing.Dict[str, str] | None,
    sync_local_sys_paths: bool,
) -> None
```
Initialize the Flyte system using an API key for authentication. This is a convenience
method for API key-based authentication. Thread-safe implementation.
The API key should be an encoded API key that contains the endpoint, client ID, client secret,
and organization information. You can obtain this encoded API key from your Flyte administrator
or cloud provider.
| Parameter | Type | Description |
|-|-|-|
| `api_key` | `str \| None` | Optional encoded API key for authentication. If None, reads from FLYTE_API_KEY environment variable. The API key is a base64-encoded string containing endpoint, client_id, client_secret, and org information. |
| `project` | `str \| None` | Optional project name |
| `domain` | `str \| None` | Optional domain name |
| `root_dir` | `Path \| None` | Optional root directory from which to determine how to load files, and find paths to files. defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd. |
| `log_level` | `int \| None` | Optional logging level for the logger |
| `log_format` | `LogFormat \| None` | Optional logging format for the logger, default is "console" |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration |
| `batch_size` | `int` | Optional batch size for operations that use listings, defaults to 1000 |
| `image_builder` | `ImageBuildEngine.ImageBuilderType` | Optional image builder configuration |
| `images` | `typing.Dict[str, str] \| None` | Optional dict of images that can be used by referencing the image name |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True) |
**Returns:** None
#### init_from_config()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await init_from_config.aio()`.
```python
def init_from_config(
    path_or_config: str | Path | Config | None,
    root_dir: Path | None,
    log_level: int | None,
    log_format: LogFormat,
    project: str | None,
    domain: str | None,
    storage: Storage | None,
    batch_size: int,
    image_builder: ImageBuildEngine.ImageBuilderType | None,
    images: tuple[str, ...] | None,
    sync_local_sys_paths: bool,
) -> None
```
Initialize the Flyte system using a configuration file or Config object. This method should be called before any
other Flyte remote API methods are called. Thread-safe implementation.
| Parameter | Type | Description |
|-|-|-|
| `path_or_config` | `str \| Path \| Config \| None` | Path to the configuration file or Config object |
| `root_dir` | `Path \| None` | Optional root directory from which to determine how to load files, and find paths to files like config etc. For example if one uses the copy-style=="all", it is essential to determine the root directory for the current project. If not provided, it defaults to the editable install directory or if not available, the current working directory. |
| `log_level` | `int \| None` | Optional logging level for the framework logger, default is set using the default initialization policies |
| `log_format` | `LogFormat` | Optional logging format for the logger, default is "console" |
| `project` | `str \| None` | Project name, this will override any project names in the configuration file |
| `domain` | `str \| None` | Domain name, this will override any domain names in the configuration file |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration if needed to access (i.e. using Minio) |
| `batch_size` | `int` | Optional batch size for operations that use listings, defaults to 1000 |
| `image_builder` | `ImageBuildEngine.ImageBuilderType \| None` | Optional image builder configuration, if provided, will override any defaults set in the configuration. |
| `images` | `tuple[str, ...] \| None` | List of image strings in format "imagename=imageuri" or just "imageuri". |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True). |
**Returns:** None
#### init_in_cluster()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await init_in_cluster.aio()`.
```python
def init_in_cluster(
    org: str | None,
    project: str | None,
    domain: str | None,
    api_key: str | None,
    endpoint: str | None,
    insecure: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `org` | `str \| None` | |
| `project` | `str \| None` | |
| `domain` | `str \| None` | |
| `api_key` | `str \| None` | |
| `endpoint` | `str \| None` | |
| `insecure` | `bool` | |
#### init_passthrough()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await init_passthrough.aio()`.
```python
def init_passthrough(
    endpoint: str | None,
    org: str | None,
    project: str | None,
    domain: str | None,
    insecure: bool,
)
```
Initialize the Flyte system with passthrough authentication.
This authentication mode allows you to pass custom authentication metadata
using the `flyte.remote.auth_metadata()` context manager.
The endpoint is automatically configured from the environment if in a flyte cluster with endpoint injected.
| Parameter | Type | Description |
|-|-|-|
| `endpoint` | `str \| None` | Optional API endpoint URL |
| `org` | `str \| None` | Optional organization name |
| `project` | `str \| None` | Optional project name |
| `domain` | `str \| None` | Optional domain name |
| `insecure` | `bool` | Whether to use an insecure channel |
**Returns:** Dictionary of remote kwargs used for initialization
#### map()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await flyte.map.aio()`.
```python
def map(
    func: typing.Union[flyte._task.AsyncFunctionTaskTemplate[~P, ~R, ~F], functools.partial[~R]],
    *args,
    group_name: str | None,
    concurrency: int,
    return_exceptions: bool,
)
```
Map a function over the provided arguments with concurrent execution.
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Union[flyte._task.AsyncFunctionTaskTemplate[~P, ~R, ~F], functools.partial[~R]]` | The async function to map. |
| `args` | `*args` | Positional arguments to pass to the function (iterables that will be zipped). |
| `group_name` | `str \| None` | The name of the group for the mapped tasks. |
| `concurrency` | `int` | The maximum number of concurrent tasks to run. If 0, run all tasks concurrently. |
| `return_exceptions` | `bool` | If True, yield exceptions instead of raising them. |
**Returns:** AsyncIterator yielding results in order.
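As an illustration only (not Flyte's implementation, and returning a list rather than an async iterator for brevity), the `concurrency` and `return_exceptions` semantics described above can be sketched in plain `asyncio`:

```python
import asyncio

async def map_concurrent(func, items, concurrency=0, return_exceptions=False):
    """Run func over items with at most `concurrency` in flight; 0 = unlimited."""
    items = list(items)
    sem = asyncio.Semaphore(concurrency if concurrency > 0 else max(len(items), 1))

    async def worker(item):
        async with sem:  # limit how many tasks run at once
            return await func(item)

    # gather preserves input order, matching "results in order"
    return await asyncio.gather(*(worker(i) for i in items),
                                return_exceptions=return_exceptions)

async def square(x: int) -> int:
    if x == 3:
        raise ValueError("boom")
    return x * x

results = asyncio.run(
    map_concurrent(square, [1, 2, 3], concurrency=2, return_exceptions=True)
)
assert results[:2] == [1, 4]
assert isinstance(results[2], ValueError)  # yielded, not raised
```

With `return_exceptions=False`, the first failure would propagate out of the call instead of appearing in the results.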
#### run()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await run.aio()`.
```python
def run(
    task: TaskTemplate[P, R, F],
    *args,
    **kwargs,
)
```
Run a task with the given parameters
| Parameter | Type | Description |
|-|-|-|
| `task` | `TaskTemplate[P, R, F]` | task to run |
| `args` | `*args` | args to pass to the task |
| `kwargs` | `**kwargs` | kwargs to pass to the task |
**Returns:** Run | Result of the task
#### run_python_script()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await run_python_script.aio()`.
```python
def run_python_script(
    script: pathlib.Path,
    cpu: int,
    memory: str,
    gpu: int,
    gpu_type: str,
    image: Union[Image, List[str], None],
    timeout: int,
    extra_args: Optional[List[str]],
    queue: Optional[str],
    wait: bool,
    name: Optional[str],
    debug: bool,
    output_dir: Optional[str],
)
```
Package and run a Python script on a remote Flyte cluster.
Bundles the script into a Flyte code bundle and executes it remotely
with the requested resources. Unlike `interactive_mode` (which
pickles the task), this approach uses an `InternalTaskResolver`
so the task can be properly debugged with `debug=True`.
Project and domain are read from the init config (set via `flyte.init()`
or `flyte.init_from_config()`), consistent with `flyte.run()`.
Example:
```python
import flyte
from pathlib import Path

flyte.init(endpoint="my-cluster.example.com")

# With a list of packages (auto-builds image)
run = flyte.run_python_script(
    Path("train.py"),
    gpu=1,
    gpu_type="A100",
    memory="64Gi",
    image=["torch", "transformers"],
)
print(run.url)

# With a custom Image object
img = flyte.Image.from_debian_base(name="my-img").with_pip_packages("numpy")
run = flyte.run_python_script(Path("analysis.py"), image=img)
```
| Parameter | Type | Description |
|-|-|-|
| `script` | `pathlib.Path` | Path to the Python script to run. |
| `cpu` | `int` | Number of CPUs to request. |
| `memory` | `str` | Memory to request, e.g. `"16Gi"`. |
| `gpu` | `int` | Number of GPUs to request. |
| `gpu_type` | `str` | GPU accelerator type. Only used when `gpu > 0` (default: `"T4"`). |
| `image` | `Union[Image, List[str], None]` | Container image to use. Accepts either - A `flyte.Image` object for full control over the image. - A `list[str]` of pip package names to install on top of the default Debian base image (e.g. `["torch", "transformers"]`). - `None` to use a plain Debian base image (default). |
| `timeout` | `int` | Task timeout in seconds. |
| `extra_args` | `Optional[List[str]]` | Extra arguments passed to the script. |
| `queue` | `Optional[str]` | Flyte queue / cluster override. |
| `wait` | `bool` | If True, block until execution completes before returning. |
| `name` | `Optional[str]` | Run name. If omitted, a random name is generated. |
| `debug` | `bool` | If True, run the task as a VS Code debug task, starting a code-server in the container so you can connect via the UI to interactively debug/run the task. |
| `output_dir` | `Optional[str]` | |
**Returns:** A `flyte.remote.Run` handle for the remote execution.
#### serve()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.:
> `result = await serve.aio()`.
```python
def serve(
    app_env: 'AppEnvironment',
)
```
Serve a Flyte app using an AppEnvironment.
This is the simple, direct way to serve an app. For more control over
deployment settings (env vars, cluster pool, etc.), use with_servecontext().
See Also:
with_servecontext: For customizing deployment settings
| Parameter | Type | Description |
|-|-|-|
| `app_env` | `'AppEnvironment'` | The app environment to serve |
**Returns:** An `AppHandle`: either a `_LocalApp` (local) or `App` (remote).
#### trace()
```python
def trace(
    func: typing.Callable[..., ~T],
)
```
A decorator that traces function execution with timing information.
Works with regular functions, async functions, and async generators/iterators.
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Callable[..., ~T]` | |
#### version()
```python
def version()
```
Returns the version of the Flyte SDK.
#### with_runcontext()
```python
def with_runcontext(
    mode: Mode | None,
    name: Optional[str],
    service_account: Optional[str],
    version: Optional[str],
    copy_style: CopyFiles,
    dry_run: bool,
    copy_bundle_to: pathlib.Path | None,
    interactive_mode: bool | None,
    raw_data_path: str | None,
    run_base_dir: str | None,
    overwrite_cache: bool,
    project: str | None,
    domain: str | None,
    env_vars: Dict[str, str] | None,
    labels: Dict[str, str] | None,
    annotations: Dict[str, str] | None,
    interruptible: bool | None,
    log_level: int | None,
    log_format: LogFormat,
    reset_root_logger: bool,
    disable_run_cache: bool,
    queue: Optional[str],
    custom_context: Dict[str, str] | None,
    cache_lookup_scope: CacheLookupScope,
    preserve_original_types: bool,
    debug: bool,
    _tracker: Any,
)
```
Launch a new run with the given parameters as the context.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Mode \| None` | Optional The mode to use for the run, if not provided, it will be computed from flyte.init |
| `name` | `Optional[str]` | Optional The name to use for the run |
| `service_account` | `Optional[str]` | Optional The service account to use for the run context |
| `version` | `Optional[str]` | Optional The version to use for the run, if not provided, it will be computed from the code bundle |
| `copy_style` | `CopyFiles` | Optional The copy style to use for the run context |
| `dry_run` | `bool` | Optional If true, the run will not be executed, but the bundle will be created |
| `copy_bundle_to` | `pathlib.Path \| None` | When dry_run is True, the bundle will be copied to this location if specified |
| `interactive_mode` | `bool \| None` | Optional, can be forced to True or False. If not provided, it will be set based on the current environment. For example Jupyter notebooks are considered interactive mode, while scripts are not. This is used to determine how the code bundle is created. |
| `raw_data_path` | `str \| None` | Use this path to store the raw data for the run for local and remote, and can be used to store raw data in specific locations. |
| `run_base_dir` | `str \| None` | Optional The base directory to use for the run. This is used to store the metadata for the run, that is passed between tasks. |
| `overwrite_cache` | `bool` | Optional If true, the cache will be overwritten for the run |
| `project` | `str \| None` | Optional The project to use for the run |
| `domain` | `str \| None` | Optional The domain to use for the run |
| `env_vars` | `Dict[str, str] \| None` | Optional Environment variables to set for the run |
| `labels` | `Dict[str, str] \| None` | Optional Labels to set for the run |
| `annotations` | `Dict[str, str] \| None` | Optional Annotations to set for the run |
| `interruptible` | `bool \| None` | Optional If true, the run can be scheduled on interruptible instances and false implies that all tasks in the run should only be scheduled on non-interruptible instances. If not specified the original setting on all tasks is retained. |
| `log_level` | `int \| None` | Optional Log level to set for the run. If not provided, it will be set to the default log level set using `flyte.init()` |
| `log_format` | `LogFormat` | Optional Log format to set for the run. If not provided, it will be set to the default log format |
| `reset_root_logger` | `bool` | If true, the root logger will be preserved and not modified by Flyte. |
| `disable_run_cache` | `bool` | Optional If true, the run cache will be disabled. This is useful for testing purposes. |
| `queue` | `Optional[str]` | Optional The queue to use for the run. This is used to specify the cluster to use for the run. |
| `custom_context` | `Dict[str, str] \| None` | Optional global input context to pass to the task. This will be available via get_custom_context() within the task and will automatically propagate to sub-tasks. Acts as base/default values that can be overridden by context managers in the code. |
| `cache_lookup_scope` | `CacheLookupScope` | Optional Scope to use for the run. This is used to specify the scope to use for cache lookups. If not specified, it will be set to the default scope (global unless overridden at the system level). |
| `preserve_original_types` | `bool` | Optional If true, the type engine will preserve original types (e.g., pd.DataFrame) when guessing python types from literal types. If false (default), it will return the generic flyte.io.DataFrame. This option is automatically set to True if interactive_mode is True unless overridden explicitly by this parameter. |
| `debug` | `bool` | Optional If true, the task will be run as a VSCode debug task, starting a code-server in the container so users can connect via the UI to interactively debug/run the task. |
| `_tracker` | `Any` | This is an internal only parameter used by the CLI to render the TUI. |
**Returns:** runner
#### with_servecontext()
```python
def with_servecontext(
    mode: ServeMode | None,
    version: Optional[str],
    copy_style: CopyFiles,
    dry_run: bool,
    project: str | None,
    domain: str | None,
    env_vars: dict[str, str] | None,
    parameter_values: dict[str, dict[str, str | flyte.io.File | flyte.io.Dir]] | None,
    cluster_pool: str | None,
    log_level: int | None,
    log_format: LogFormat,
    interactive_mode: bool | None,
    copy_bundle_to: pathlib.Path | None,
    deactivate_timeout: float | None,
    activate_timeout: float | None,
    health_check_timeout: float | None,
    health_check_interval: float | None,
    health_check_path: str | None,
)
```
Create a serve context with custom configuration.
This function allows you to customize how an app is served, including
overriding environment variables, cluster pool, logging, and other deployment settings.
Use `mode="local"` to serve the app on localhost (non-blocking) so you can
immediately invoke tasks that call the app endpoint.
Use `mode="remote"` (or omit `mode` when a Flyte client is configured) to
deploy the app to the Flyte backend.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `ServeMode \| None` | "local" to run on localhost, "remote" to deploy to the Flyte backend. When `None` the mode is inferred from the current configuration. |
| `version` | `Optional[str]` | Optional version override for the app deployment |
| `copy_style` | `CopyFiles` | Code bundle copy style. Options: "loaded_modules", "all", "none" (default: "loaded_modules") |
| `dry_run` | `bool` | If True, don't actually deploy (default: False) |
| `project` | `str \| None` | Optional project override |
| `domain` | `str \| None` | Optional domain override |
| `env_vars` | `dict[str, str] \| None` | Optional environment variables to inject/override in the app container |
| `parameter_values` | `dict[str, dict[str, str \| flyte.io.File \| flyte.io.Dir]] \| None` | Optional parameter values to inject/override in the app container. Must be a dictionary that maps app environment names to a dictionary of parameter names to values. |
| `cluster_pool` | `str \| None` | Optional cluster pool to deploy the app to |
| `log_level` | `int \| None` | Optional log level (e.g., logging.DEBUG, logging.INFO). If not provided, uses init config or default |
| `log_format` | `LogFormat` | Optional log format ("console" or "json", default: "console") |
| `interactive_mode` | `bool \| None` | Optional, can be forced to True or False. If not provided, it is set based on the current environment: Jupyter notebooks are considered interactive mode, while scripts are not. This determines how the code bundle is created and whether the app is served in interactive mode. |
| `copy_bundle_to` | `pathlib.Path \| None` | When dry_run is True, the bundle will be copied to this location if specified |
| `deactivate_timeout` | `float \| None` | Timeout in seconds for waiting for the app to stop during `deactivate(wait=True)`. Defaults to 6 s. |
| `activate_timeout` | `float \| None` | Total timeout in seconds when polling the health-check endpoint during `activate(wait=True)`. Defaults to 60 s. |
| `health_check_timeout` | `float \| None` | Per-request timeout in seconds for each health-check HTTP request. Defaults to 2 s. |
| `health_check_interval` | `float \| None` | Interval in seconds between consecutive health-check polls. Defaults to 1 s. |
| `health_check_path` | `str \| None` | URL path used for the local health-check probe (e.g. `"/healthz"`). Defaults to `"/health"`. |
**Returns:** `_Serve`: serve context manager with the configured settings.
**Raises**
| Exception | Description |
|-|-|
| `NotImplementedError` | If called from a notebook/interactive environment (remote mode only) |
> [!NOTE]
> - Apps do not support pickle-based bundling (interactive mode)
> - LOG_LEVEL and LOG_FORMAT are automatically set as env vars if not explicitly provided in env_vars
> - The env_vars and cluster_pool overrides mutate the app IDL after creation
> - This is a temporary solution until the API natively supports these fields
## Subpages
- **Flyte SDK > Packages > flyte > AppHandle**
- **Flyte SDK > Packages > flyte > Cache**
- **Flyte SDK > Packages > flyte > CachePolicy**
- **Flyte SDK > Packages > flyte > Cron**
- **Flyte SDK > Packages > flyte > Device**
- **Flyte SDK > Packages > flyte > Environment**
- **Flyte SDK > Packages > flyte > FixedRate**
- **Flyte SDK > Packages > flyte > Image**
- **Flyte SDK > Packages > flyte > ImageBuild**
- **Flyte SDK > Packages > flyte > Link**
- **Flyte SDK > Packages > flyte > PodTemplate**
- **Flyte SDK > Packages > flyte > Resources**
- **Flyte SDK > Packages > flyte > RetryStrategy**
- **Flyte SDK > Packages > flyte > ReusePolicy**
- **Flyte SDK > Packages > flyte > Secret**
- **Flyte SDK > Packages > flyte > TaskEnvironment**
- **Flyte SDK > Packages > flyte > Timeout**
- **Flyte SDK > Packages > flyte > Trigger**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/apphandle ===
# AppHandle
**Package:** `flyte`
Protocol defining the common interface between local and remote app handles.
Both `_LocalApp` (local serving) and `App` (remote serving) satisfy this
protocol, enabling calling code to work uniformly regardless of the serving mode.
```python
protocol AppHandle()
```
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
| `name` | `None` | |
| `url` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > AppHandle > Methods > activate()** | |
| **Flyte SDK > Packages > flyte > AppHandle > Methods > deactivate()** | |
| **Flyte SDK > Packages > flyte > AppHandle > Methods > ephemeral_ctx()** | |
| **Flyte SDK > Packages > flyte > AppHandle > Methods > ephemeral_ctx_sync()** | |
| **Flyte SDK > Packages > flyte > AppHandle > Methods > is_active()** | |
| **Flyte SDK > Packages > flyte > AppHandle > Methods > is_deactivated()** | |
### activate()
```python
def activate(
wait: bool,
) -> AppHandle
```
| Parameter | Type | Description |
|-|-|-|
| `wait` | `bool` | |
### deactivate()
```python
def deactivate(
wait: bool,
) -> AppHandle
```
| Parameter | Type | Description |
|-|-|-|
| `wait` | `bool` | |
### ephemeral_ctx()
```python
def ephemeral_ctx()
```
### ephemeral_ctx_sync()
```python
def ephemeral_ctx_sync()
```
### is_active()
```python
def is_active()
```
### is_deactivated()
```python
def is_deactivated()
```
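Because `AppHandle` is a `typing.Protocol`, handle classes satisfy it structurally rather than by subclassing it. A minimal sketch of how that works, using a reduced, hypothetical member set (not the full protocol above):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class AppHandle(Protocol):
    # Reduced, hypothetical member set for illustration only.
    def is_active(self) -> bool: ...
    def is_deactivated(self) -> bool: ...

class LocalApp:  # note: never subclasses AppHandle
    def is_active(self) -> bool:
        return True
    def is_deactivated(self) -> bool:
        return False

# Structural typing: LocalApp satisfies the protocol by shape alone.
assert isinstance(LocalApp(), AppHandle)
assert not isinstance(object(), AppHandle)
```

This is why calling code can work uniformly with `_LocalApp` and `App`: both expose the same members, so either satisfies the protocol.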
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/cache ===
# Cache
**Package:** `flyte`
Cache configuration for a task.
Three cache behaviors are available:
- `"auto"`: Cache version is computed automatically from cache policies
(default: `FunctionBodyPolicy`, which hashes the function source code).
Any change to the function body invalidates the cache.
- `"override"`: You provide an explicit `version_override` string.
Cache is only invalidated when you change the version.
- `"disable"`: Caching is disabled; the task always re-executes.
Set via `TaskEnvironment(cache=...)`, `@env.task(cache=...)`, or
`task.override(cache=...)`.
## Parameters
```python
class Cache(
behavior: typing.Literal['auto', 'override', 'disable'],
version_override: typing.Optional[str],
serialize: bool,
ignored_inputs: typing.Union[typing.Tuple[str, ...], str],
salt: str,
policies: typing.Union[typing.List[flyte._cache.cache.CachePolicy], flyte._cache.cache.CachePolicy, NoneType],
)
```
| Parameter | Type | Description |
|-|-|-|
| `behavior` | `typing.Literal['auto', 'override', 'disable']` | Cache behavior – `"auto"`, `"override"`, or `"disable"`. |
| `version_override` | `typing.Optional[str]` | Explicit cache version string. Only used when `behavior="override"`. |
| `serialize` | `bool` | If `True`, concurrent executions with identical inputs are serialized – only one runs and the rest wait for and reuse the cached result. Default `False`. |
| `ignored_inputs` | `typing.Union[typing.Tuple[str, ...], str]` | Input parameter names to exclude from the cache key. Useful when some inputs (e.g., timestamps) shouldn't affect caching. |
| `salt` | `str` | Additional salt for cache key generation. Use to create separate cache namespaces (e.g., `salt="v2"` to invalidate all existing caches). |
| `policies` | `typing.Union[typing.List[flyte._cache.cache.CachePolicy], flyte._cache.cache.CachePolicy, NoneType]` | Cache policies for version generation. Defaults to `[FunctionBodyPolicy()]` when `behavior="auto"`. Provide a custom `CachePolicy` implementation for alternative versioning strategies. |
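The `"auto"` behavior derives the version by hashing the task function's source, so any edit to the body produces a new version. The standalone sketch below (a hypothetical `body_version` helper, not the SDK's actual implementation) illustrates that idea with `inspect` and `hashlib`:

```python
import hashlib
import inspect


def body_version(fn, salt: str = "") -> str:
    # Hash the function's source code, mimicking the idea behind
    # FunctionBodyPolicy: any edit to the body yields a new version.
    src = inspect.getsource(fn)
    return hashlib.sha256(f"{salt}{src}".encode()).hexdigest()


def step_a(x: int) -> int:
    return x + 1


def step_b(x: int) -> int:
    return x + 2  # different body -> different version


assert body_version(step_a) != body_version(step_b)
# A salt creates a separate cache namespace for the same body.
assert body_version(step_a) != body_version(step_a, salt="v2")
```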
## Methods
| Method | Description |
|-|-|
| `get_ignored_inputs()` | |
| `get_version()` | |
| `is_enabled()` | Check if the cache policy is enabled. |
### get_ignored_inputs()
```python
def get_ignored_inputs()
```
### get_version()
```python
def get_version(
params: typing.Optional[flyte._cache.cache.VersionParameters],
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `params` | `typing.Optional[flyte._cache.cache.VersionParameters]` | |
### is_enabled()
```python
def is_enabled()
```
Check if the cache policy is enabled.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/cachepolicy ===
# CachePolicy
**Package:** `flyte`
Protocol for custom cache version strategies.
Implement `get_version(salt, params) -> str` to define how cache versions
are computed. The default implementation is `FunctionBodyPolicy`, which
hashes the function source code.
Example custom policy:
```python
class GitHashPolicy:
    def get_version(self, salt: str, params: VersionParameters) -> str:
        import hashlib
        import subprocess

        git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
        return hashlib.sha256(f"{salt}{git_hash}".encode()).hexdigest()
```
```python
protocol CachePolicy()
```
## Methods
| Method | Description |
|-|-|
| `get_version()` | |
### get_version()
```python
def get_version(
salt: str,
params: flyte._cache.cache.VersionParameters,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `salt` | `str` | |
| `params` | `flyte._cache.cache.VersionParameters` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/cron ===
# Cron
**Package:** `flyte`
Cron-based automation schedule for use with `Trigger`.
Cron expressions use the standard five-field format:
`minute hour day-of-month month day-of-week`
Common patterns:
- `"0 * * * *"` – every hour (at minute 0)
- `"0 0 * * *"` – daily at midnight
- `"0 0 * * 1"` – weekly on Monday at midnight
- `"0 0 1 * *"` – monthly on the 1st at midnight
- `"*/5 * * * *"` – every 5 minutes
Example:
```python
my_trigger = flyte.Trigger(
name="my_cron_trigger",
automation=flyte.Cron("0 * * * *"), # Runs every hour
description="A trigger that runs every hour",
)
```
## Parameters
```python
class Cron(
expression: str,
timezone: Timezone,
)
```
| Parameter | Type | Description |
|-|-|-|
| `expression` | `str` | Cron expression string (e.g., `"0 * * * *"`). |
| `timezone` | `Timezone` | Timezone for the cron schedule (default `"UTC"`). One of the standard timezone values (e.g., `"US/Eastern"`, `"Europe/London"`). Note that DST transitions may cause skipped or duplicated runs. |
## Properties
| Property | Type | Description |
|-|-|-|
| `timezone_expression` | `None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/device ===
# Device
**Package:** `flyte`
Represents a device type, its quantity, and its partition if applicable.
## Parameters
```python
class Device(
quantity: int,
device_class: typing.Literal['GPU', 'TPU', 'NEURON', 'AMD_GPU', 'HABANA_GAUDI'],
device: str | None,
partition: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `quantity` | `int` | The number of devices of this type. |
| `device_class` | `typing.Literal['GPU', 'TPU', 'NEURON', 'AMD_GPU', 'HABANA_GAUDI']` | The class of accelerator. |
| `device` | `str \| None` | The type of device (e.g., `"T4"`, `"A100"`). |
| `partition` | `str \| None` | The partition of the device (e.g., `"1g.5gb"`, `"2g.10gb"` for GPUs, or `"1x1"`, ... for TPUs). |
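A `Device` can be passed to `Resources(gpu=...)` when the shorthand string form (e.g., `"A100:1"`) is not expressive enough. A minimal configuration sketch, assuming `Device` is importable from the top-level `flyte` package as this page's placement suggests (values are illustrative):

```python
import flyte
from flyte import Device

# One MIG partition of an A100, expressed as an explicit Device.
resources = flyte.Resources(
    gpu=Device(quantity=1, device_class="GPU", device="A100", partition="1g.5gb"),
)
```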
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/environment ===
# Environment
**Package:** `flyte`
## Parameters
```python
class Environment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
image: Union[str, Image, Literal['auto'], None],
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name of the environment |
| `depends_on` | `List[Environment]` | Environment dependencies to hint, so when you deploy the environment, the dependencies are also deployed. This is useful when you have a set of environments that depend on each other. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Pod template to use for the environment. |
| `description` | `Optional[str]` | Description of the environment. |
| `secrets` | `Optional[SecretRequest]` | Secrets to inject into the environment. |
| `env_vars` | `Optional[Dict[str, str]]` | Environment variables to set for the environment. |
| `resources` | `Optional[Resources]` | Resources to allocate for the environment. |
| `interruptible` | `bool` | Whether the environment is interruptible and can be scheduled on spot/preemptible instances |
| `image` | `Union[str, Image, Literal['auto'], None]` | Docker image to use for the environment. If set to "auto", will use the default image. |
## Methods
| Method | Description |
|-|-|
| `add_dependency()` | Add a dependency to the environment. |
| `clone_with()` | |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[List[Environment]],
description: Optional[str],
**kwargs,
) -> Environment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[List[Environment]]` | |
| `description` | `Optional[str]` | |
| `kwargs` | `**kwargs` | |
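Putting the pieces together, a minimal configuration sketch of defining an environment, deriving a variant with `clone_with()`, and wiring a deployment dependency (names and values are illustrative):

```python
import flyte

base = flyte.Environment(name="base", image="auto")

# clone_with() returns a new Environment with the selected fields overridden.
trainer = base.clone_with(
    name="trainer",
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)

# Deploying trainer will also deploy anything it depends_on.
trainer.add_dependency(base)
```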
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/fixedrate ===
# FixedRate
**Package:** `flyte`
Fixed-rate (interval-based) automation schedule for use with `Trigger`.
Unlike `Cron`, which runs at specific clock times, `FixedRate` runs at a
consistent interval regardless of clock time.
Example:
```python
my_trigger = flyte.Trigger(
name="my_fixed_rate_trigger",
automation=flyte.FixedRate(60), # Runs every 60 minutes
description="A trigger that runs every hour",
)
```
## Parameters
```python
class FixedRate(
interval_minutes: int,
start_time: datetime | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `interval_minutes` | `int` | Interval between trigger activations, in minutes (e.g., `60` for hourly, `1440` for daily). |
| `start_time` | `datetime \| None` | Optional start time for the first trigger. Subsequent triggers follow the interval from this point. If not set, the first trigger occurs `interval_minutes` after deployment/activation. |
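Fixed-rate firing times are simple interval arithmetic from the start time. The stdlib sketch below (a hypothetical `fire_times` helper, not part of the SDK) shows the schedule a `FixedRate(60)` with a given `start_time` would produce:

```python
from datetime import datetime, timedelta


def fire_times(start_time: datetime, interval_minutes: int, count: int) -> list[datetime]:
    # Each activation is exactly interval_minutes after the previous one,
    # regardless of wall-clock alignment (unlike Cron).
    step = timedelta(minutes=interval_minutes)
    return [start_time + i * step for i in range(count)]


start = datetime(2025, 1, 1, 9, 30)
print(fire_times(start, 60, 3))  # 09:30, 10:30, 11:30
```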
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/image ===
# Image
**Package:** `flyte`
Container image specification built using a fluent, two-step pattern:
1. Create a base image with a `from_*` constructor
2. Customize with `with_*` methods (each returns a new `Image`)
Example:
```python
image = (
flyte.Image.from_debian_base(python="3.12")
.with_pip_packages("pandas", "scikit-learn")
.with_apt_packages("curl", "git")
)
```
**Base constructors** (`from_*`):
- `from_debian_base()` β Debian-based image with a specified Python version
- `from_base()` β Any base image by name (e.g., `"python:3.12-slim"`)
- `from_uv_script()` β Image from a `uv`-compatible script with inline dependencies
- `from_dockerfile()` β Image from a custom Dockerfile
- `from_ref_name()` β Reference to a pre-built image by name
**Customization methods** (`with_*`):
- `with_pip_packages()` β Add pip packages
- `with_apt_packages()` β Add system packages via apt-get
- `with_commands()` β Run arbitrary shell commands
- `with_env_vars()` β Set environment variables
- `with_requirements()` β Install from a requirements.txt file
- `with_uv_project()` β Install from a uv/pyproject.toml project
- `with_poetry_project()` β Install from a Poetry project
- `with_source_folder()` β Include a source directory
- `with_source_file()` β Include a single source file
- `with_code_bundle()` β Include a code bundle
- `with_workdir()` β Set the working directory
- `with_dockerignore()` β Add a .dockerignore
- `with_local_v2()` β Configure for local v2 execution
## Parameters
```python
class Image(
base_image: Optional[str],
dockerfile: Optional[Path],
registry: Optional[str],
name: Optional[str],
platform: Tuple[Architecture, ...],
python_version: Tuple[int, int],
extendable: bool,
_ref_name: Optional[str],
_layers: Tuple[Layer, ...],
_image_registry_secret: Optional[Secret],
)
```
| Parameter | Type | Description |
|-|-|-|
| `base_image` | `Optional[str]` | |
| `dockerfile` | `Optional[Path]` | |
| `registry` | `Optional[str]` | |
| `name` | `Optional[str]` | |
| `platform` | `Tuple[Architecture, ...]` | |
| `python_version` | `Tuple[int, int]` | |
| `extendable` | `bool` | |
| `_ref_name` | `Optional[str]` | |
| `_layers` | `Tuple[Layer, ...]` | |
| `_image_registry_secret` | `Optional[Secret]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `uri` | `None` | Returns the URI of the image in the format `<registry>/<name>:<tag>`. |
## Methods
| Method | Description |
|-|-|
| `clone()` | Use this method to clone the current image and change the registry and name. |
| `from_base()` | Use this method to start with a pre-built base image. |
| `from_debian_base()` | Use this method to start from the default base image, built from this library's base Dockerfile. |
| `from_dockerfile()` | Use this method to create a new image from the specified Dockerfile. |
| `from_ref_name()` | |
| `from_uv_script()` | Use this method to create a new image from the specified uv script. |
| `validate()` | |
| `with_apt_packages()` | Use this method to create a new image with the specified apt packages layered on top of the current image. |
| `with_code_bundle()` | Configure this image to automatically copy source code from root_dir. |
| `with_commands()` | Use this method to create a new image with the specified commands layered on top of the current image. |
| `with_dockerignore()` | |
| `with_env_vars()` | Use this method to create a new image with the specified environment variables layered on top of the current image. |
| `with_local_v2()` | Use this method to create a new image with the local v2 builder. |
| `with_pip_packages()` | Use this method to create a new image with the specified pip packages layered on top of the current image. |
| `with_poetry_project()` | Use this method to create a new image with the specified Poetry project layered on top of the current image. |
| `with_requirements()` | Use this method to create a new image with the specified requirements file layered on top of the current image. |
| `with_source_file()` | Use this method to create a new image with the specified local file(s) layered on top of the current image. |
| `with_source_folder()` | Use this method to create a new image with the specified local directory layered on top of the current image. |
| `with_uv_project()` | Use this method to create a new image with the specified uv project layered on top of the current image. |
| `with_workdir()` | Use this method to create a new image with the specified working directory. |
### clone()
```python
def clone(
registry: Optional[str],
registry_secret: Optional[str | Secret],
name: Optional[str],
base_image: Optional[str],
python_version: Optional[Tuple[int, int]],
addl_layer: Optional[Layer],
extendable: Optional[bool],
) -> Image
```
Use this method to clone the current image and change the registry and name.
| Parameter | Type | Description |
|-|-|-|
| `registry` | `Optional[str]` | Registry to use for the image |
| `registry_secret` | `Optional[str \| Secret]` | Secret to use to pull/push the private image. |
| `name` | `Optional[str]` | Name of the image |
| `base_image` | `Optional[str]` | Base image to use for the image |
| `python_version` | `Optional[Tuple[int, int]]` | Python version for the image, if not specified, will use the current Python version |
| `addl_layer` | `Optional[Layer]` | Additional layer to add to the image. This will be added to the end of the layers. |
| `extendable` | `Optional[bool]` | Whether the image is extendable by other images. If True, the image can be used as a base image for other images, and additional layers can be added on top of it. If False, the image cannot be used as a base image for other images, and additional layers cannot be added on top of it. If None (default), defaults to False for safety. |
### from_base()
```python
def from_base(
image_uri: str,
) -> Image
```
Use this method to start with a pre-built base image. The image must already exist in the registry.
| Parameter | Type | Description |
|-|-|-|
| `image_uri` | `str` | The full URI of the image, in the format `<registry>/<name>` |
### from_debian_base()
```python
def from_debian_base(
python_version: Optional[Tuple[int, int]],
flyte_version: Optional[str],
install_flyte: bool,
registry: Optional[str],
registry_secret: Optional[str | Secret],
name: Optional[str],
platform: Optional[Tuple[Architecture, ...]],
) -> Image
```
Use this method to start from the default base image, built from this library's base Dockerfile.
Default images are multi-arch (linux/amd64 and linux/arm64).
| Parameter | Type | Description |
|-|-|-|
| `python_version` | `Optional[Tuple[int, int]]` | If not specified, will use the current Python version |
| `flyte_version` | `Optional[str]` | Flyte version to use |
| `install_flyte` | `bool` | If True, will install the flyte library in the image |
| `registry` | `Optional[str]` | Registry to use for the image |
| `registry_secret` | `Optional[str \| Secret]` | Secret to use to pull/push the private image. |
| `name` | `Optional[str]` | Name of the image if you want to override the default name |
| `platform` | `Optional[Tuple[Architecture, ...]]` | Platform(s) to use for the image; default is linux/amd64. Use a tuple for multiple values, e.g. `("linux/amd64", "linux/arm64")`. |
**Returns:** Image
### from_dockerfile()
```python
def from_dockerfile(
file: Path,
registry: str,
name: str,
platform: Union[Architecture, Tuple[Architecture, ...], None],
) -> Image
```
Use this method to create a new image from the specified Dockerfile. Note that you cannot add
additional layers after this: the system does not attempt to parse the Dockerfile to understand its
setup (Python version, uv vs. poetry, etc.), so put all build logic into the Dockerfile itself.
Also, since Python resolves paths relative to the calling directory, use `Path` objects with
absolute paths. The build context is the directory containing the Dockerfile.
| Parameter | Type | Description |
|-|-|-|
| `file` | `Path` | path to the dockerfile |
| `registry` | `str` | registry to use for the image |
| `name` | `str` | name of the image |
| `platform` | `Union[Architecture, Tuple[Architecture, ...], None]` | Architecture(s) to use for the image; default is linux/amd64. Use a tuple for multiple values, e.g. `("linux/amd64", "linux/arm64")`. |
### from_ref_name()
```python
def from_ref_name(
name: str,
) -> Image
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
### from_uv_script()
```python
def from_uv_script(
script: Path | str,
name: str,
registry: str | None,
registry_secret: Optional[str | Secret],
python_version: Optional[Tuple[int, int]],
index_url: Optional[str],
extra_index_urls: Union[str, List[str], Tuple[str, ...], None],
pre: bool,
extra_args: Optional[str],
platform: Optional[Tuple[Architecture, ...]],
secret_mounts: Optional[SecretRequest],
) -> Image
```
Use this method to create a new image from the specified uv script.
It uses the script's header to determine the Python version and the dependencies to install.
The script must be a valid uv script; otherwise an error is raised.
The header of the script typically looks like this:
Example:
```python
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = ["httpx"]
# ///
```
For more information on the uv script format, see the documentation:
[UV: Declaring script dependencies](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies)
| Parameter | Type | Description |
|-|-|-|
| `script` | `Path \| str` | path to the uv script |
| `name` | `str` | name of the image |
| `registry` | `str \| None` | registry to use for the image |
| `registry_secret` | `Optional[str \| Secret]` | Secret to use to pull/push the private image. |
| `python_version` | `Optional[Tuple[int, int]]` | Python version for the image, if not specified, will use the current Python version |
| `index_url` | `Optional[str]` | index url to use for pip install, default is None |
| `extra_index_urls` | `Union[str, List[str], Tuple[str, ...], None]` | extra index urls to use for pip install, default is None |
| `pre` | `bool` | whether to allow pre-release versions, default is False |
| `extra_args` | `Optional[str]` | extra arguments to pass to pip install, default is None |
| `platform` | `Optional[Tuple[Architecture, ...]]` | architecture to use for the image, default is linux/amd64, use tuple for multiple values |
| `secret_mounts` | `Optional[SecretRequest]` | |
**Returns:** Image
### validate()
```python
def validate()
```
### with_apt_packages()
```python
def with_apt_packages(
packages: str,
secret_mounts: Optional[SecretRequest],
) -> Image
```
Use this method to create a new image with the specified apt packages layered on top of the current image
| Parameter | Type | Description |
|-|-|-|
| `packages` | `str` | list of apt packages to install |
| `secret_mounts` | `Optional[SecretRequest]` | list of secret mounts to use for the build process. |
**Returns:** Image
### with_code_bundle()
```python
def with_code_bundle(
copy_style: Literal['loaded_modules', 'all'],
dst: str,
) -> Image
```
Configure this image to automatically copy source code from root_dir
when the runner's copy_style is "none".
When the runner's copy_style is not "none", this is a no-op.
| Parameter | Type | Description |
|-|-|-|
| `copy_style` | `Literal['loaded_modules', 'all']` | Which files to copy into the image. "loaded_modules" copies only imported Python modules. "all" copies all files from root_dir. |
| `dst` | `str` | Destination directory in the container. Defaults to working dir. |
**Returns:** Image
### with_commands()
```python
def with_commands(
commands: List[str],
secret_mounts: Optional[SecretRequest],
) -> Image
```
Use this method to create a new image with the specified commands layered on top of the current image
Be sure not to use RUN in your command.
| Parameter | Type | Description |
|-|-|-|
| `commands` | `List[str]` | list of commands to run |
| `secret_mounts` | `Optional[SecretRequest]` | list of secret mounts to use for the build process. |
**Returns:** Image
### with_dockerignore()
```python
def with_dockerignore(
path: Path,
) -> Image
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `Path` | |
### with_env_vars()
```python
def with_env_vars(
env_vars: Dict[str, str],
) -> Image
```
Use this method to create a new image with the specified environment variables layered on top of
the current image. Cannot be used in conjunction with conda
| Parameter | Type | Description |
|-|-|-|
| `env_vars` | `Dict[str, str]` | dictionary of environment variables to set |
**Returns:** Image
### with_local_v2()
```python
def with_local_v2()
```
Use this method to create a new image with the local v2 builder
This will override any existing builder
**Returns:** Image
### with_pip_packages()
```python
def with_pip_packages(
packages: str,
index_url: Optional[str],
extra_index_urls: Union[str, List[str], Tuple[str, ...], None],
pre: bool,
extra_args: Optional[str],
secret_mounts: Optional[SecretRequest],
) -> Image
```
Use this method to create a new image with the specified pip packages layered on top of the current image
Cannot be used in conjunction with conda
Example:
```python
@flyte.task(image=(flyte.Image.from_debian_base().with_pip_packages("requests", "numpy")))
def my_task(x: int) -> int:
import numpy as np
return np.sum([x, 1])
```
To mount secrets during the build process (e.g., to download private packages), use the `secret_mounts` parameter.
In the example below, the secret "GITHUB_PAT" is mounted as the environment variable "GITHUB_PAT",
and "apt-secret" is mounted at /etc/apt/apt-secret.
Example:
```python
private_package = "git+https://$GITHUB_PAT@github.com/flyteorg/flytex.git@2e20a2acebfc3877d84af643fdd768edea41d533"

@flyte.task(
    image=(
        flyte.Image.from_debian_base()
        # git must be installed before pip can resolve the git+ URL below.
        .with_apt_packages("git", secret_mounts=[Secret(key="apt-secret", mount="/etc/apt/apt-secret")])
        .with_pip_packages(private_package, secret_mounts=[Secret(key="GITHUB_PAT")])
    )
)
def my_task(x: int) -> int:
    import numpy as np

    return np.sum([x, 1])
```
| Parameter | Type | Description |
|-|-|-|
| `packages` | `str` | list of pip packages to install, follows pip install syntax |
| `index_url` | `Optional[str]` | index url to use for pip install, default is None |
| `extra_index_urls` | `Union[str, List[str], Tuple[str, ...], None]` | extra index urls to use for pip install, default is None |
| `pre` | `bool` | whether to allow pre-release versions, default is False |
| `extra_args` | `Optional[str]` | extra arguments to pass to pip install, default is None |
| `secret_mounts` | `Optional[SecretRequest]` | list of secret to mount for the build process. |
**Returns:** Image
### with_poetry_project()
```python
def with_poetry_project(
pyproject_file: str | Path,
poetry_lock: Path | None,
extra_args: Optional[str],
secret_mounts: Optional[SecretRequest],
project_install_mode: typing.Literal['dependencies_only', 'install_project'],
)
```
Use this method to create a new image with the specified Poetry project layered on top of the current image.
A corresponding poetry.lock file must exist in the same directory as the pyproject.toml file unless
`poetry_lock` is explicitly provided.
Cannot be used in conjunction with conda.
By default, this method copies the entire project into the image,
including files such as pyproject.toml, poetry.lock, and the src/ directory.
If you prefer not to install the current project, pass `--no-root` via `extra_args`;
the image builder will then copy only pyproject.toml and poetry.lock into the image.
| Parameter | Type | Description |
|-|-|-|
| `pyproject_file` | `str \| Path` | Path to the pyproject.toml file. A poetry.lock file must exist in the same directory unless `poetry_lock` is explicitly provided. |
| `poetry_lock` | `Path \| None` | Path to the poetry.lock file. If not specified, the default is the file named 'poetry.lock' in the same directory as `pyproject_file` (pyproject.parent / "poetry.lock"). |
| `extra_args` | `Optional[str]` | Extra arguments to pass through to the package installer/resolver, default is None. |
| `secret_mounts` | `Optional[SecretRequest]` | Secrets to make available during dependency resolution/build (e.g., private indexes). |
| `project_install_mode` | `typing.Literal['dependencies_only', 'install_project']` | whether to install the project as a package or only dependencies, default is "dependencies_only" |
**Returns:** Image
### with_requirements()
```python
def with_requirements(
file: str | Path,
index_url: Optional[str],
extra_index_urls: Union[str, List[str], Tuple[str, ...], None],
pre: bool,
extra_args: Optional[str],
secret_mounts: Optional[SecretRequest],
) -> Image
```
Use this method to create a new image with the specified requirements file layered on top of the current image
Cannot be used in conjunction with conda
| Parameter | Type | Description |
|-|-|-|
| `file` | `str \| Path` | path to the requirements file, must be a .txt file |
| `index_url` | `Optional[str]` | index url to use for pip install, default is None |
| `extra_index_urls` | `Union[str, List[str], Tuple[str, ...], None]` | extra index urls to use for pip install, default is None |
| `pre` | `bool` | if True, install pre-release packages, default is False |
| `extra_args` | `Optional[str]` | extra arguments to pass to pip install, default is None |
| `secret_mounts` | `Optional[SecretRequest]` | list of secret to mount for the build process. |
### with_source_file()
```python
def with_source_file(
src: typing.Union[Path, typing.List[Path]],
dst: str,
) -> Image
```
Use this method to create a new image with the specified local file(s) layered on top of the current image.
If `dst` is not specified, the file(s) are copied to the image's working directory.
| Parameter | Type | Description |
|-|-|-|
| `src` | `typing.Union[Path, typing.List[Path]]` | file or list of files from the build context to be copied |
| `dst` | `str` | destination folder in the image |
**Returns:** Image
### with_source_folder()
```python
def with_source_folder(
src: Path,
dst: str,
copy_contents_only: bool,
) -> Image
```
Use this method to create a new image with the specified local directory layered on top of the current image.
If `dst` is not specified, the directory is copied to the image's working directory.
| Parameter | Type | Description |
|-|-|-|
| `src` | `Path` | root folder of the source code from the build context to be copied |
| `dst` | `str` | destination folder in the image |
| `copy_contents_only` | `bool` | If True, will copy the contents of the source folder to the destination folder, instead of the folder itself. Default is False. |
**Returns:** Image
### with_uv_project()
```python
def with_uv_project(
pyproject_file: str | Path,
uvlock: Path | None,
index_url: Optional[str],
extra_index_urls: Union[List[str], Tuple[str, ...], None],
pre: bool,
extra_args: Optional[str],
secret_mounts: Optional[SecretRequest],
project_install_mode: typing.Literal['dependencies_only', 'install_project'],
) -> Image
```
Use this method to create a new image with the specified uv.lock file layered on top of the current image.
A corresponding pyproject.toml file must exist in the same directory.
Cannot be used in conjunction with conda.
By default, this method copies only the pyproject.toml and uv.lock files into the image.
If `project_install_mode` is "install_project", it also copies the directory containing
the pyproject.toml file into the image.
| Parameter | Type | Description |
|-|-|-|
| `pyproject_file` | `str \| Path` | path to the pyproject.toml file |
| `uvlock` | `Path \| None` | Path to the uv.lock file. If not specified, defaults to the file named "uv.lock" in the same directory as the pyproject.toml file, if it exists (`pyproject.parent / "uv.lock"`). |
| `index_url` | `Optional[str]` | index url to use for pip install, default is None |
| `extra_index_urls` | `Union[List[str], Tuple[str, ...], None]` | extra index urls to use for pip install, default is None |
| `pre` | `bool` | whether to allow pre-release versions, default is False |
| `extra_args` | `Optional[str]` | extra arguments to pass to pip install, default is None |
| `secret_mounts` | `Optional[SecretRequest]` | list of secret mounts to use for the build process. |
| `project_install_mode` | `typing.Literal['dependencies_only', 'install_project']` | whether to install the project as a package or only dependencies, default is "dependencies_only" |
**Returns:** Image
### with_workdir()
```python
def with_workdir(
workdir: str,
) -> Image
```
Use this method to create a new image with the specified working directory
This will override any existing working directory
| Parameter | Type | Description |
|-|-|-|
| `workdir` | `str` | working directory to use |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/imagebuild ===
# ImageBuild
**Package:** `flyte`
Result of an image build operation.
## Parameters
```python
class ImageBuild(
uri: str | None,
remote_run: Optional['remote.Run'],
)
```
| Parameter | Type | Description |
|-|-|-|
| `uri` | `str \| None` | The fully qualified image URI. None if the build was started asynchronously and hasn't completed yet. |
| `remote_run` | `Optional['remote.Run']` | The Run object that kicked off an image build job when using the remote builder. None when using the local builder. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/link ===
# Link
**Package:** `flyte`
```python
protocol Link()
```
## Methods
| Method | Description |
|-|-|
| `get_link()` | Returns a task log link given the action. |
### get_link()
```python
def get_link(
run_name: str,
project: str,
domain: str,
context: typing.Dict[str, str],
parent_action_name: str,
action_name: str,
pod_name: str,
**kwargs,
) -> str
```
Returns a task log link given the action.
Link can have template variables that are replaced by the backend.
| Parameter | Type | Description |
|-|-|-|
| `run_name` | `str` | The name of the run. |
| `project` | `str` | The project name. |
| `domain` | `str` | The domain name. |
| `context` | `typing.Dict[str, str]` | Additional context for generating the link. |
| `parent_action_name` | `str` | The name of the parent action. |
| `action_name` | `str` | The name of the action. |
| `pod_name` | `str` | The name of the pod. |
| `kwargs` | `**kwargs` | Additional keyword arguments. |
**Returns:** The generated link.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/podtemplate ===
# PodTemplate
**Package:** `flyte`
Custom PodTemplate specification for a Task.
## Parameters
```python
class PodTemplate(
pod_spec: typing.Optional[ForwardRef('V1PodSpec')],
primary_container_name: str,
labels: typing.Optional[typing.Dict[str, str]],
annotations: typing.Optional[typing.Dict[str, str]],
)
```
| Parameter | Type | Description |
|-|-|-|
| `pod_spec` | `typing.Optional[ForwardRef('V1PodSpec')]` | |
| `primary_container_name` | `str` | |
| `labels` | `typing.Optional[typing.Dict[str, str]]` | |
| `annotations` | `typing.Optional[typing.Dict[str, str]]` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > PodTemplate > Methods > to_k8s_pod()** | |
### to_k8s_pod()
```python
def to_k8s_pod()
```
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/resources ===
# Resources
**Package:** `flyte`
Resources such as CPU, Memory, and GPU that can be allocated to a task.
Set via `TaskEnvironment(resources=...)` or `task.override(resources=...)`.
Examples:
```python
# Simple: 1 CPU, 1 GiB memory, 1 T4 GPU
Resources(cpu=1, memory="1Gi", gpu="T4:1")
# Range: request 1 CPU (limit 2), 2 GiB memory, 8 A100 GPUs, 10 GiB disk
Resources(cpu=(1, 2), memory="2Gi", gpu="A100:8", disk="10Gi")
# Advanced GPU with partitioning
Resources(gpu=GPU(device="A100", quantity=1, partition="1g.5gb"))
# TPU
Resources(gpu=TPU(device="V5P", partition="2x2x1"))
```
## Parameters
```python
class Resources(
cpu: typing.Union[int, float, str, typing.Tuple[int | float | str, int | float | str], NoneType],
memory: typing.Union[str, typing.Tuple[str, str], NoneType],
gpu: typing.Union[typing.Literal['A10:1', 'A10:2', 'A10:3', 'A10:4', 'A10:5', 'A10:6', 'A10:7', 'A10:8', 'A10G:1', 'A10G:2', 'A10G:3', 'A10G:4', 'A10G:5', 'A10G:6', 'A10G:7', 'A10G:8', 'A100:1', 'A100:2', 'A100:3', 'A100:4', 'A100:5', 'A100:6', 'A100:7', 'A100:8', 'A100 80G:1', 'A100 80G:2', 'A100 80G:3', 'A100 80G:4', 'A100 80G:5', 'A100 80G:6', 'A100 80G:7', 'A100 80G:8', 'B200:1', 'B200:2', 'B200:3', 'B200:4', 'B200:5', 'B200:6', 'B200:7', 'B200:8', 'H100:1', 'H100:2', 'H100:3', 'H100:4', 'H100:5', 'H100:6', 'H100:7', 'H100:8', 'H200:1', 'H200:2', 'H200:3', 'H200:4', 'H200:5', 'H200:6', 'H200:7', 'H200:8', 'L4:1', 'L4:2', 'L4:3', 'L4:4', 'L4:5', 'L4:6', 'L4:7', 'L4:8', 'L40s:1', 'L40s:2', 'L40s:3', 'L40s:4', 'L40s:5', 'L40s:6', 'L40s:7', 'L40s:8', 'V100:1', 'V100:2', 'V100:3', 'V100:4', 'V100:5', 'V100:6', 'V100:7', 'V100:8', 'RTX PRO 6000:1', 'GB10:1', 'T4:1', 'T4:2', 'T4:3', 'T4:4', 'T4:5', 'T4:6', 'T4:7', 'T4:8', 'Trn1:1', 'Trn1:4', 'Trn1:8', 'Trn1:16', 'Trn1n:1', 'Trn1n:4', 'Trn1n:8', 'Trn1n:16', 'Trn2:1', 'Trn2:4', 'Trn2:8', 'Trn2:16', 'Trn2u:1', 'Trn2u:4', 'Trn2u:8', 'Trn2u:16', 'Inf1:1', 'Inf1:2', 'Inf1:3', 'Inf1:4', 'Inf1:5', 'Inf1:6', 'Inf1:7', 'Inf1:8', 'Inf1:9', 'Inf1:10', 'Inf1:11', 'Inf1:12', 'Inf1:13', 'Inf1:14', 'Inf1:15', 'Inf1:16', 'Inf2:1', 'Inf2:2', 'Inf2:3', 'Inf2:4', 'Inf2:5', 'Inf2:6', 'Inf2:7', 'Inf2:8', 'Inf2:9', 'Inf2:10', 'Inf2:11', 'Inf2:12', 'MI100:1', 'MI210:1', 'MI250:1', 'MI250X:1', 'MI300A:1', 'MI300X:1', 'MI325X:1', 'MI350X:1', 'MI355X:1', 'Gaudi1:1'], int, flyte._resources.Device, NoneType],
disk: typing.Optional[str],
shm: typing.Union[str, typing.Literal['auto'], NoneType],
)
```
| Parameter | Type | Description |
|-|-|-|
| `cpu` | `typing.Union[int, float, str, typing.Tuple[int \| float \| str, int \| float \| str], NoneType]` | CPU allocation. Accepts several formats - `int` or `float`: CPU cores (e.g., `1`, `0.5`) - `str`: Kubernetes-style (e.g., `"500m"` for 0.5 cores, `"2"` for 2 cores) - `tuple`: Request/limit range (e.g., `(1, 4)` requests 1 core, limits to 4) |
| `memory` | `typing.Union[str, typing.Tuple[str, str], NoneType]` | Memory allocation using Kubernetes unit conventions - Binary units: `"512Mi"`, `"1Gi"`, `"4Gi"` - Decimal units: `"500M"`, `"1G"` - `tuple`: Request/limit range (e.g., `("1Gi", "4Gi")`) |
| `gpu` | `typing.Union[typing.Literal[...], int, flyte._resources.Device, NoneType]` | GPU, TPU, or other accelerator allocation (the full list of accepted `"TYPE:count"` literal strings appears in the class signature above). Accepts: - `int`: GPU count, any available type (e.g., `1`, `4`) - `str`: Type and quantity (e.g., `"T4:1"`, `"A100:2"`, `"H100:8"`) - `Device`: Advanced config via `GPU()`, `TPU()`, or `Device()` for partitioning and custom device types. See `GPU`, `TPU`, `Device` for details. Supported GPU types include T4, L4, L40s, A10, A10G, A100, A100 80G, B200, H100, H200, V100. GPU partitioning (MIG) is available on A100, A100 80G, H100, and H200. |
| `disk` | `typing.Optional[str]` | Ephemeral disk storage as a string with Kubernetes units (e.g., `"10Gi"`, `"100Gi"`, `"1Ti"`). Automatically cleaned up when the task completes. |
| `shm` | `typing.Union[str, typing.Literal['auto'], NoneType]` | Shared memory (`/dev/shm`) allocation. Useful for ML data loading and inter-process communication: - `str`: Size with units (e.g., `"1Gi"`, `"16Gi"`) - `"auto"`: Set to the maximum shared memory available on the node |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > Resources > Methods > get_device()** | Get the accelerator string for the task. |
| **Flyte SDK > Packages > flyte > Resources > Methods > get_shared_memory()** | Get the shared memory string for the task. |
### get_device()
```python
def get_device()
```
Get the accelerator string for the task.
Default cloud provider labels typically use the following values: `1g.5gb`, `2g.10gb`, etc.
**Returns:** If GPUs are requested, a tuple of the device name and, potentially, a partition string.
### get_shared_memory()
```python
def get_shared_memory()
```
Get the shared memory string for the task.
**Returns:** The shared memory string.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/retrystrategy ===
# RetryStrategy
**Package:** `flyte`
Retry strategy for a task or task environment. A retry strategy is optional and can be given either as a `RetryStrategy` or as a simple number of retries.
Example:
- This will retry the task 5 times.
```python
@task(retries=5)
def my_task():
    pass
```
- This will also retry the task 5 times, using an explicit `RetryStrategy`.
```python
@task(retries=RetryStrategy(count=5))
def my_task():
    pass
```
## Parameters
```python
class RetryStrategy(
count: int,
)
```
| Parameter | Type | Description |
|-|-|-|
| `count` | `int` | The number of retries. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/reusepolicy ===
# ReusePolicy
**Package:** `flyte`
Configure a task environment for container reuse across multiple task invocations.
When environment creation is expensive relative to task runtime, reusable containers
keep a pool of warm containers ready, avoiding cold-start overhead. The Python process
may be reused by subsequent task invocations.
Total concurrent capacity is `max_replicas * concurrency`. For example,
`ReusePolicy(replicas=(1, 3), concurrency=2)` supports up to 6 concurrent tasks.
Caution: the environment is shared across invocations, so manage memory and resources carefully.
Example:
```python
env = flyte.TaskEnvironment(
    name="fast_env",
    reusable=flyte.ReusePolicy(replicas=(1, 3), concurrency=2),
)
```
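The capacity formula above can be checked with plain arithmetic. In this sketch, `capacity` is an illustrative helper mirroring the `max_replicas * concurrency` rule, not part of the SDK:

```python
from typing import Tuple, Union

def capacity(replicas: Union[int, Tuple[int, int]], concurrency: int = 1) -> int:
    """Illustrative: maximum concurrent tasks for a ReusePolicy-style
    configuration, i.e. max_replicas * concurrency."""
    # A bare int is a fixed replica count; a tuple is (min, max).
    max_replicas = replicas if isinstance(replicas, int) else replicas[1]
    return max_replicas * concurrency

# Matches the example above: ReusePolicy(replicas=(1, 3), concurrency=2)
print(capacity((1, 3), 2))  # 6
```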
## Parameters
```python
class ReusePolicy(
replicas: typing.Union[int, typing.Tuple[int, int]],
idle_ttl: typing.Union[int, datetime.timedelta],
concurrency: int,
scaledown_ttl: typing.Union[int, datetime.timedelta],
)
```
| Parameter | Type | Description |
|-|-|-|
| `replicas` | `typing.Union[int, typing.Tuple[int, int]]` | Number of container replicas to maintain. - `int`: Fixed replica count, always running (e.g., `replicas=3`). - `tuple(min, max)`: Auto-scaling range (e.g., `replicas=(1, 5)`). Scales between min and max based on demand. Default is `2`. A minimum of 2 replicas is recommended to avoid starvation when the parent task occupies one replica. |
| `idle_ttl` | `typing.Union[int, datetime.timedelta]` | Environment-level idle timeout: shuts down **all** replicas when the entire environment has been idle for this duration. Specified as seconds (`int`) or `timedelta`. Minimum 30 seconds. Default is 30 seconds. |
| `concurrency` | `int` | Maximum concurrent tasks per replica. Values greater than 1 are only supported for `async` tasks. Default is `1`. |
| `scaledown_ttl` | `typing.Union[int, datetime.timedelta]` | Per-replica scale-down delay: minimum time to wait before removing an **individual** idle replica. Prevents rapid scale-down when tasks arrive in bursts. Specified as seconds (`int`) or `timedelta`. Default is 30 seconds. Note the distinction: `idle_ttl` controls when the whole environment shuts down; `scaledown_ttl` controls when individual replicas are removed during auto-scaling. |
## Properties
| Property | Type | Description |
|-|-|-|
| `max_replicas` | `None` | Returns the maximum number of replicas. |
| `min_replicas` | `None` | Returns the minimum number of replicas. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > ReusePolicy > Methods > get_scaledown_ttl()** | Returns the scaledown TTL as a timedelta. |
### get_scaledown_ttl()
```python
def get_scaledown_ttl()
```
Returns the scaledown TTL as a timedelta. If scaledown_ttl is not set, returns None.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/secret ===
# Secret
**Package:** `flyte`
Secrets are used to inject sensitive information into tasks or image build context.
Secrets can be mounted as environment variables or files.
The secret key is the name of the secret in the secret store. The group is optional and may be used with some
secret stores to organize secrets. The optional `as_env_var` parameter specifies the name of the
environment variable that the secret is mounted as.
Example:
```python
import os

@task(secrets="my-secret")
async def my_task():
    # This will be set to the value of the secret.
    # Note: the env var name is always uppercase, and "-" is replaced with "_".
    os.environ["MY_SECRET"]

@task(secrets=Secret("my-openai-api-key", as_env_var="OPENAI_API_KEY"))
async def my_task2():
    os.environ["OPENAI_API_KEY"]
```
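The uppercase/underscore naming convention noted in the example above can be sketched as a small helper. This is illustrative only; the SDK performs this mapping internally:

```python
def secret_env_name(key: str) -> str:
    """Illustrative: the environment variable name a secret key maps to
    when no explicit as_env_var is given. The key is uppercased and
    '-' is replaced with '_'."""
    return key.upper().replace("-", "_")

print(secret_env_name("my-secret"))          # MY_SECRET
print(secret_env_name("my-openai-api-key"))  # MY_OPENAI_API_KEY
```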
TODO: Add support for secret versioning (some stores) and secret groups (some stores) and mounting as files.
## Parameters
```python
class Secret(
key: str,
group: typing.Optional[str],
mount: pathlib.Path | None,
as_env_var: typing.Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `key` | `str` | The name of the secret in the secret store. |
| `group` | `typing.Optional[str]` | The group of the secret in the secret store. |
| `mount` | `pathlib.Path \| None` | For now, the only supported mount path is "/etc/flyte/secrets". TODO: support arbitrary mount paths. |
| `as_env_var` | `typing.Optional[str]` | The name of the environment variable that the secret should be mounted as. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > Secret > Methods > stable_hash()** | Deterministic, process-independent hash (as hex string). |
### stable_hash()
```python
def stable_hash()
```
Deterministic, process-independent hash (as hex string).
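The SDK does not document the exact algorithm here, but the general idea of a deterministic, process-independent hash can be sketched with `hashlib` (the field serialization below is an assumption for illustration only):

```python
import hashlib
from typing import Optional

def stable_hash(key: str, group: Optional[str] = None) -> str:
    """Illustrative: a process-independent hex digest over a secret's
    identifying fields. Unlike Python's built-in hash(), which is salted
    per process, this yields the same value across runs and machines."""
    payload = f"{key}|{group or ''}".encode("utf-8")
    return hashlib.sha256(payload).hexdigest()

print(stable_hash("my-secret"))
```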
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/taskenvironment ===
# TaskEnvironment
**Package:** `flyte`
Define an execution environment for a set of tasks.
Task configuration in Flyte has three levels (most general to most specific):
1. **TaskEnvironment**: sets defaults for all tasks in the environment
2. **@env.task decorator**: overrides per-task settings
3. **task.override()**: overrides at invocation time
For shared parameters, the more specific level overrides the more general one.
Example:
```python
env = flyte.TaskEnvironment(
    name="my_env",
    image=flyte.Image.from_debian_base(python="3.12").with_pip_packages("pandas"),
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)

@env.task
async def my_task():
    pass
```
## Parameters
```python
class TaskEnvironment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
image: Union[str, Image, Literal['auto'], None],
cache: CacheRequest,
reusable: ReusePolicy | None,
plugin_config: Optional[Any],
queue: Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name of the environment (required). Must be snake_case or kebab-case. TaskEnvironment level only. |
| `depends_on` | `List[Environment]` | List of other environments this one depends on. Used at deploy time to ensure dependencies are also deployed. TaskEnvironment level only. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Kubernetes pod template for advanced configuration (sidecars, volumes, etc.). Also settable in `@env.task` and `task.override`. |
| `description` | `Optional[str]` | Human-readable description (max 255 characters). TaskEnvironment level only. |
| `secrets` | `Optional[SecretRequest]` | Secrets to inject. Overridable via `task.override(secrets=...)` when not using reusable containers. |
| `env_vars` | `Optional[Dict[str, str]]` | Environment variables as `dict[str, str]`. Overridable via `task.override(env_vars=...)` when not using reusable containers. |
| `resources` | `Optional[Resources]` | Compute resources (CPU, memory, GPU, disk). Overridable via `task.override(resources=...)` when not using reusable containers. |
| `interruptible` | `bool` | Whether tasks can run on spot/preemptible instances. Also settable in `@env.task` and `task.override`. |
| `image` | `Union[str, Image, Literal['auto'], None]` | Docker image for the environment. Can be a string (image URI), an `Image` object, or `"auto"` to use the default image. TaskEnvironment level only. |
| `cache` | `CacheRequest` | Cache policy: `"auto"`, `"override"`, `"disable"`, or a `Cache` object. Also settable in `@env.task(cache=...)` and `task.override(cache=...)`. |
| `reusable` | `ReusePolicy \| None` | `ReusePolicy` for container reuse. Also overridable via `task.override(reusable=...)`. |
| `plugin_config` | `Optional[Any]` | Plugin configuration for custom task types (e.g., Ray, Spark). Cannot be combined with `reusable`. TaskEnvironment level only. |
| `queue` | `Optional[str]` | Queue name for scheduling. Also settable in `@env.task` and `task.override`. |
## Properties
| Property | Type | Description |
|-|-|-|
| `sandbox` | `None` | Access the sandbox namespace for creating sandboxed tasks. |
| `tasks` | `None` | Get all tasks defined in the environment. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > TaskEnvironment > Methods > add_dependency()** | Add a dependency to the environment. |
| **Flyte SDK > Packages > flyte > TaskEnvironment > Methods > clone_with()** | Clone the TaskEnvironment with new parameters. |
| **Flyte SDK > Packages > flyte > TaskEnvironment > Methods > from_task()** | Create a TaskEnvironment from a list of tasks. |
| **Flyte SDK > Packages > flyte > TaskEnvironment > Methods > task()** | Decorate a function to be a task. |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[List[Environment]],
description: Optional[str],
interruptible: Optional[bool],
**kwargs,
) -> TaskEnvironment
```
Clone the TaskEnvironment with new parameters.
Besides the base environment parameters, you can override kwargs like `cache`, `reusable`, etc.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the environment. |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | The image to use for the environment. |
| `resources` | `Optional[Resources]` | The resources to allocate for the environment. |
| `env_vars` | `Optional[Dict[str, str]]` | The environment variables to set for the environment. |
| `secrets` | `Optional[SecretRequest]` | The secrets to inject into the environment. |
| `depends_on` | `Optional[List[Environment]]` | The environment dependencies to hint, so when you deploy the environment, the dependencies are also deployed. This is useful when you have a set of environments that depend on each other. |
| `description` | `Optional[str]` | The description of the environment. |
| `interruptible` | `Optional[bool]` | Whether the environment is interruptible and can be scheduled on spot/preemptible instances. |
| `kwargs` | `**kwargs` | Additional parameters to override the environment (e.g., cache, reusable, plugin_config). |
### from_task()
```python
def from_task(
name: str,
tasks: TaskTemplate,
depends_on: Optional[List['Environment']],
) -> TaskEnvironment
```
Create a TaskEnvironment from a list of tasks. All tasks should have the same image, or no image defined.
Image equality is determined by Python object identity, not by value.
If images differ, an error is raised. If no image is defined, the image is set to "auto".
Other tasks that need to use these tasks can reference the returned environment in the `depends_on`
attribute of their own TaskEnvironment.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the environment. |
| `tasks` | `TaskTemplate` | The list of tasks to create the environment from. |
| `depends_on` | `Optional[List['Environment']]` | Optional list of environments that this environment depends on. |
**Returns:** The created TaskEnvironment.
**Raises**
| Exception | Description |
|-|-|
| `ValueError` | If tasks are assigned to multiple environments or have different images. |
### task()
```python
def task(
_func: F | None,
short_name: Optional[str],
cache: CacheRequest | None,
retries: Union[int, RetryStrategy],
timeout: Union[timedelta, int],
docs: Optional[Documentation],
pod_template: Optional[Union[str, PodTemplate]],
report: bool,
interruptible: bool | None,
max_inline_io_bytes: int,
queue: Optional[str],
triggers: Tuple[Trigger, ...] | Trigger,
links: Tuple[Link, ...] | Link,
task_resolver: Any | None,
) -> Callable[[F], AsyncFunctionTaskTemplate[P, R, F]] | AsyncFunctionTaskTemplate[P, R, F]
```
Decorate a function to be a task.
| Parameter | Type | Description |
|-|-|-|
| `_func` | `F \| None` | Optional. The function to decorate. If not provided, the decorator returns a callable that accepts a function to be decorated. |
| `short_name` | `Optional[str]` | Optional. A friendly name for the task (defaults to the function name). |
| `cache` | `CacheRequest \| None` | Optional. The cache policy for the task. Defaults to "auto", which caches the results of the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional. The number of retries for the task. Defaults to 0, which means no retries. |
| `timeout` | `Union[timedelta, int]` | Optional. The timeout for the task. |
| `docs` | `Optional[Documentation]` | Optional. The documentation for the task. If not provided, the function docstring is used. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional. The pod template for the task. If not provided, the default pod template is used. |
| `report` | `bool` | Optional. Whether to generate an HTML report for the task. Defaults to False. |
| `interruptible` | `bool \| None` | Optional. Whether the task is interruptible. Defaults to the environment setting. |
| `max_inline_io_bytes` | `int` | Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes. |
| `queue` | `Optional[str]` | Optional. Queue name to use for this task. If not set, the environment's queue is used. |
| `triggers` | `Tuple[Trigger, ...] \| Trigger` | Optional. A tuple of triggers to associate with the task, allowing it to run on a schedule or in response to events. Triggers can be defined using the `flyte.trigger` module. |
| `links` | `Tuple[Link, ...] \| Link` | Optional. A tuple of links to associate with the task. Links can provide additional context or information about the task, and should implement the `flyte.Link` protocol. |
| `task_resolver` | `Any \| None` | |
**Returns:** A TaskTemplate that can be used to deploy the task.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/timeout ===
# Timeout
**Package:** `flyte`
Timeout class to define a timeout for a task.
The task timeout can be set to a maximum runtime and a maximum queued time.
Maximum runtime is the maximum time the task can run for (in one attempt).
Maximum queued time is the maximum time the task can stay in the queue before it starts executing.
Example usage:
```python
timeout = Timeout(max_runtime=timedelta(minutes=5), max_queued_time=timedelta(minutes=10))

@env.task(timeout=timeout)
async def my_task():
    pass
```
## Parameters
```python
class Timeout(
max_runtime: datetime.timedelta | int,
max_queued_time: datetime.timedelta | int | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `max_runtime` | `datetime.timedelta \| int` | Maximum runtime for the task. If given as an `int`, it is interpreted as seconds. |
| `max_queued_time` | `datetime.timedelta \| int \| None` | Optional. Maximum queued time for the task. If given as an `int`, it is interpreted as seconds. Defaults to None. |
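The int-to-timedelta conversion described above is equivalent to the following. The helper `as_timedelta` is illustrative, not the SDK's implementation:

```python
from datetime import timedelta
from typing import Union

def as_timedelta(value: Union[timedelta, int]) -> timedelta:
    """Illustrative: ints are interpreted as a number of seconds."""
    return value if isinstance(value, timedelta) else timedelta(seconds=value)

print(as_timedelta(300))                   # 0:05:00
print(as_timedelta(timedelta(minutes=5)))  # 0:05:00
```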
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte/trigger ===
# Trigger
**Package:** `flyte`
Specification for a scheduled trigger that can be associated with any Flyte task.
Triggers run tasks on a schedule (cron or fixed-rate). They are set only in the
`@env.task` decorator via the `triggers` parameter. The same `Trigger` object
can be associated with multiple tasks.
Predefined convenience constructors are available: `Trigger.hourly()`,
`Trigger.daily()`, `Trigger.weekly()`, `Trigger.monthly()`, and
`Trigger.minutely()`.
Example:
```python
my_trigger = flyte.Trigger(
    name="my_trigger",
    description="A trigger that runs every hour",
    inputs={"start_time": flyte.TriggerTime, "x": 1},
    automation=flyte.FixedRate(60),
)

@env.task(triggers=my_trigger)
async def my_task(start_time: datetime, x: int) -> str:
    ...
```
## Parameters
```python
class Trigger(
name: str,
automation: Union[Cron, FixedRate],
description: str,
auto_activate: bool,
inputs: Dict[str, Any] | None,
env_vars: Dict[str, str] | None,
interruptible: bool | None,
overwrite_cache: bool,
queue: str | None,
labels: Mapping[str, str] | None,
annotations: Mapping[str, str] | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Unique name for the trigger (required). |
| `automation` | `Union[Cron, FixedRate]` | Schedule type: `Cron(...)` or `FixedRate(...)` (required). |
| `description` | `str` | Human-readable description (max 255 characters). Default `""`. |
| `auto_activate` | `bool` | Whether to activate the trigger automatically on deployment. Default `True`. |
| `inputs` | `Dict[str, Any] \| None` | Default input values for triggered runs. Use `flyte.TriggerTime` to bind the trigger's scheduled time to an input parameter. |
| `env_vars` | `Dict[str, str] \| None` | Environment variables for triggered runs (overrides the task's configured values). |
| `interruptible` | `bool \| None` | Whether triggered runs use spot/preemptible instances. `None` (default) preserves the task's configured behavior; any other value overrides it. |
| `overwrite_cache` | `bool` | Force cache refresh on triggered runs. Default `False`. |
| `queue` | `str \| None` | Queue name for triggered runs (overrides the task's configured value). |
| `labels` | `Mapping[str, str] \| None` | Kubernetes labels to attach to triggered runs. |
| `annotations` | `Mapping[str, str] \| None` | Kubernetes annotations to attach to triggered runs. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > Trigger > Methods > daily()** | Creates a Cron trigger that runs daily at midnight. |
| **Flyte SDK > Packages > flyte > Trigger > Methods > hourly()** | Creates a Cron trigger that runs every hour. |
| **Flyte SDK > Packages > flyte > Trigger > Methods > minutely()** | Creates a Cron trigger that runs every minute. |
| **Flyte SDK > Packages > flyte > Trigger > Methods > monthly()** | Creates a Cron trigger that runs monthly on the 1st at midnight. |
| **Flyte SDK > Packages > flyte > Trigger > Methods > weekly()** | Creates a Cron trigger that runs weekly on Sundays at midnight. |
### daily()
```python
def daily(
trigger_time_input_key: str | None,
name: str,
description: str,
auto_activate: bool,
inputs: Dict[str, Any] | None,
env_vars: Dict[str, str] | None,
interruptible: bool | None,
overwrite_cache: bool,
queue: str | None,
labels: Mapping[str, str] | None,
annotations: Mapping[str, str] | None,
) -> Trigger
```
Creates a Cron trigger that runs daily at midnight.
| Parameter | Type | Description |
|-|-|-|
| `trigger_time_input_key` | `str \| None` | The input key for the trigger time. If None, no trigger time input is added. |
| `name` | `str` | The name of the trigger, default is "daily". |
| `description` | `str` | A description of the trigger. |
| `auto_activate` | `bool` | Whether the trigger should be automatically activated. |
| `inputs` | `Dict[str, Any] \| None` | Optional inputs for the trigger. |
| `env_vars` | `Dict[str, str] \| None` | Optional environment variables. |
| `interruptible` | `bool \| None` | Whether the triggered run is interruptible. |
| `overwrite_cache` | `bool` | Whether to overwrite the cache. |
| `queue` | `str \| None` | Optional queue to run the trigger in. |
| `labels` | `Mapping[str, str] \| None` | Optional labels to attach to the trigger. |
| `annotations` | `Mapping[str, str] \| None` | Optional annotations to attach to the trigger. |
**Returns:** Trigger: A trigger that runs daily at midnight.
### hourly()
```python
def hourly(
trigger_time_input_key: str | None,
name: str,
description: str,
auto_activate: bool,
inputs: Dict[str, Any] | None,
env_vars: Dict[str, str] | None,
interruptible: bool | None,
overwrite_cache: bool,
queue: str | None,
labels: Mapping[str, str] | None,
annotations: Mapping[str, str] | None,
) -> Trigger
```
Creates a Cron trigger that runs every hour.
| Parameter | Type | Description |
|-|-|-|
| `trigger_time_input_key` | `str \| None` | The input parameter for the trigger time. If None, no trigger time input is added. |
| `name` | `str` | The name of the trigger, default is "hourly". |
| `description` | `str` | A description of the trigger. |
| `auto_activate` | `bool` | Whether the trigger should be automatically activated. |
| `inputs` | `Dict[str, Any] \| None` | Optional inputs for the trigger. |
| `env_vars` | `Dict[str, str] \| None` | Optional environment variables. |
| `interruptible` | `bool \| None` | Whether the trigger is interruptible. |
| `overwrite_cache` | `bool` | Whether to overwrite the cache. |
| `queue` | `str \| None` | Optional queue to run the trigger in. |
| `labels` | `Mapping[str, str] \| None` | Optional labels to attach to the trigger. |
| `annotations` | `Mapping[str, str] \| None` | Optional annotations to attach to the trigger. |
**Returns:** Trigger: A trigger that runs every hour, on the hour.
### minutely()
```python
def minutely(
trigger_time_input_key: str | None,
name: str,
description: str,
auto_activate: bool,
inputs: Dict[str, Any] | None,
env_vars: Dict[str, str] | None,
interruptible: bool | None,
overwrite_cache: bool,
queue: str | None,
labels: Mapping[str, str] | None,
annotations: Mapping[str, str] | None,
) -> Trigger
```
Creates a Cron trigger that runs every minute.
| Parameter | Type | Description |
|-|-|-|
| `trigger_time_input_key` | `str \| None` | The input parameter for the trigger time. If None, no trigger time input is added. |
| `name` | `str` | The name of the trigger, default is "every_minute". |
| `description` | `str` | A description of the trigger. |
| `auto_activate` | `bool` | Whether the trigger should be automatically activated. |
| `inputs` | `Dict[str, Any] \| None` | Optional inputs for the trigger. |
| `env_vars` | `Dict[str, str] \| None` | Optional environment variables. |
| `interruptible` | `bool \| None` | Whether the trigger is interruptible. |
| `overwrite_cache` | `bool` | Whether to overwrite the cache. |
| `queue` | `str \| None` | Optional queue to run the trigger in. |
| `labels` | `Mapping[str, str] \| None` | Optional labels to attach to the trigger. |
| `annotations` | `Mapping[str, str] \| None` | Optional annotations to attach to the trigger. |
**Returns:** Trigger: A trigger that runs every minute.
### monthly()
```python
def monthly(
trigger_time_input_key: str | None,
name: str,
description: str,
auto_activate: bool,
inputs: Dict[str, Any] | None,
env_vars: Dict[str, str] | None,
interruptible: bool | None,
overwrite_cache: bool,
queue: str | None,
labels: Mapping[str, str] | None,
annotations: Mapping[str, str] | None,
) -> Trigger
```
Creates a Cron trigger that runs monthly on the 1st at midnight.
| Parameter | Type | Description |
|-|-|-|
| `trigger_time_input_key` | `str \| None` | The input parameter for the trigger time. If None, no trigger time input is added. |
| `name` | `str` | The name of the trigger, default is "monthly". |
| `description` | `str` | A description of the trigger. |
| `auto_activate` | `bool` | Whether the trigger should be automatically activated. |
| `inputs` | `Dict[str, Any] \| None` | Optional inputs for the trigger. |
| `env_vars` | `Dict[str, str] \| None` | Optional environment variables. |
| `interruptible` | `bool \| None` | Whether the trigger is interruptible. |
| `overwrite_cache` | `bool` | Whether to overwrite the cache. |
| `queue` | `str \| None` | Optional queue to run the trigger in. |
| `labels` | `Mapping[str, str] \| None` | Optional labels to attach to the trigger. |
| `annotations` | `Mapping[str, str] \| None` | Optional annotations to attach to the trigger. |
**Returns:** Trigger: A trigger that runs monthly on the 1st at midnight.
### weekly()
```python
def weekly(
trigger_time_input_key: str | None,
name: str,
description: str,
auto_activate: bool,
inputs: Dict[str, Any] | None,
env_vars: Dict[str, str] | None,
interruptible: bool | None,
overwrite_cache: bool,
queue: str | None,
labels: Mapping[str, str] | None,
annotations: Mapping[str, str] | None,
) -> Trigger
```
Creates a Cron trigger that runs weekly on Sundays at midnight.
| Parameter | Type | Description |
|-|-|-|
| `trigger_time_input_key` | `str \| None` | The input parameter for the trigger time. If None, no trigger time input is added. |
| `name` | `str` | The name of the trigger, default is "weekly". |
| `description` | `str` | A description of the trigger. |
| `auto_activate` | `bool` | Whether the trigger should be automatically activated. |
| `inputs` | `Dict[str, Any] \| None` | Optional inputs for the trigger. |
| `env_vars` | `Dict[str, str] \| None` | Optional environment variables. |
| `interruptible` | `bool \| None` | Whether the trigger is interruptible. |
| `overwrite_cache` | `bool` | Whether to overwrite the cache. |
| `queue` | `str \| None` | Optional queue to run the trigger in. |
| `labels` | `Mapping[str, str] \| None` | Optional labels to attach to the trigger. |
| `annotations` | `Mapping[str, str] \| None` | Optional annotations to attach to the trigger. |
**Returns:** Trigger: A trigger that runs weekly on Sundays at midnight.
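As a sketch, a trigger built from the `weekly()` helper above might look like the following. The exact import path for these helpers is not shown in this reference, so treat it as an assumption; the keyword arguments match the signature documented above.

```python
# Assumption: the cron helpers documented above are importable from a
# flyte.trigger module; adjust the import to match your SDK version.
from flyte.trigger import weekly

report_trigger = weekly(
    trigger_time_input_key="kickoff_time",  # task input that receives the trigger time
    name="weekly-report",
    description="Generate the weekly report every Sunday at midnight",
    auto_activate=True,
)
```

The resulting `Trigger` fires every Sunday at midnight, passing the trigger time to the task through the `kickoff_time` input.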
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app ===
# flyte.app
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app > AppEndpoint** | Embed an upstream app's endpoint as an app parameter. |
| **Flyte SDK > Packages > flyte.app > AppEnvironment** | Configure a long-running app environment for APIs, dashboards, or model servers. |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment** | |
| **Flyte SDK > Packages > flyte.app > Domain** | Custom domain configuration for an app. |
| **Flyte SDK > Packages > flyte.app > Link** | Custom links to add to the app. |
| **Flyte SDK > Packages > flyte.app > Parameter** | Parameter for application. |
| **Flyte SDK > Packages > flyte.app > Port** | |
| **Flyte SDK > Packages > flyte.app > RunOutput** | Use a run's output for app parameters. |
| **Flyte SDK > Packages > flyte.app > Scaling** | Controls replica count and autoscaling behavior for app environments. |
| **Flyte SDK > Packages > flyte.app > Timeouts** | Timeout configuration for the application. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app > Methods > ctx()** | Returns the current app context. |
| **Flyte SDK > Packages > flyte.app > Methods > get_parameter()** | Get parameters for application or endpoint. |
## Methods
#### ctx()
```python
def ctx()
```
Returns the current app context.
**Returns:** AppContext
#### get_parameter()
```python
def get_parameter(
name: str,
) -> str
```
Get parameters for application or endpoint.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
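As a minimal sketch, the two methods above can be used from inside a serving process. Here `"model_path"` is a hypothetical parameter name; it only resolves if a `Parameter` with that name was declared on the `AppEnvironment`.

```python
import flyte.app

# Read the current app context (an AppContext instance).
context = flyte.app.ctx()

# Look up a declared parameter by name; "model_path" is a hypothetical
# parameter defined on the AppEnvironment.
model_path = flyte.app.get_parameter("model_path")
```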
## Subpages
- **Flyte SDK > Packages > flyte.app > AppEndpoint**
- **Flyte SDK > Packages > flyte.app > AppEnvironment**
- **Flyte SDK > Packages > flyte.app > ConnectorEnvironment**
- **Flyte SDK > Packages > flyte.app > Domain**
- **Flyte SDK > Packages > flyte.app > Link**
- **Flyte SDK > Packages > flyte.app > Parameter**
- **Flyte SDK > Packages > flyte.app > Port**
- **Flyte SDK > Packages > flyte.app > RunOutput**
- **Flyte SDK > Packages > flyte.app > Scaling**
- **Flyte SDK > Packages > flyte.app > Timeouts**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/appendpoint ===
# AppEndpoint
**Package:** `flyte.app`
Embed an upstream app's endpoint as an app parameter.
This enables the declaration of an app parameter dependency on the endpoint of
an upstream app, given by a specific app name. This gives the app access to
the upstream app's endpoint as a public or private URL.
## Parameters
```python
class AppEndpoint(
type: typing.Literal['string'],
app_name: str,
public: bool,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `type` | `typing.Literal['string']` | |
| `app_name` | `str` | |
| `public` | `bool` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app > AppEndpoint > Methods > check_type()** | |
| **Flyte SDK > Packages > flyte.app > AppEndpoint > Methods > get()** | |
| **Flyte SDK > Packages > flyte.app > AppEndpoint > Methods > materialize()** | Returns the AppEndpoint object; the endpoint is retrieved at serving time by the fserve executable. |
### check_type()
```python
def check_type(
data: typing.Any,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `data` | `typing.Any` | |
### get()
```python
def get()
```
### materialize()
```python
def materialize()
```
Returns the AppEndpoint object; the endpoint is retrieved at serving time by the fserve executable.
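As a sketch of how `AppEndpoint` is typically used, the following declares a downstream app whose environment depends on the endpoint of an upstream app. The app name `"backend-api"` and the environment variable `BACKEND_URL` are hypothetical; the fields match the constructor documented above.

```python
import flyte.app

# Hypothetical frontend app that depends on the endpoint of an upstream
# app named "backend-api"; the private URL is exposed via BACKEND_URL.
frontend = flyte.app.AppEnvironment(
    name="frontend",
    port=8080,
    parameters=[
        flyte.app.Parameter(
            name="backend_url",
            value=flyte.app.AppEndpoint(
                type="string",
                app_name="backend-api",
                public=False,  # use the upstream app's private URL
            ),
            env_var="BACKEND_URL",
        ),
    ],
)
```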
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/appenvironment ===
# AppEnvironment
**Package:** `flyte.app`
Configure a long-running app environment for APIs, dashboards, or model servers.
Example:
```python
app_env = flyte.app.AppEnvironment(
name="my-api",
image=flyte.Image.from_debian_base(python="3.12").with_pip_packages("fastapi", "uvicorn"),
port=8080,
scaling=flyte.app.Scaling(replicas=(1, 3)),
)
```
## Parameters
```python
class AppEnvironment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
image: Union[str, Image, Literal['auto'], None],
type: Optional[str],
port: int | Port,
args: *args,
command: Optional[Union[List[str], str]],
requires_auth: bool,
scaling: Scaling,
domain: Domain | None,
links: List[Link],
include: List[str],
parameters: List[Parameter],
cluster_pool: str,
timeouts: Timeouts,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name of the app (required). Must be lowercase alphanumeric with hyphens. Inherited from Environment. |
| `depends_on` | `List[Environment]` | Dependencies on other environments (deployed together). Inherited from Environment. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `description` | `Optional[str]` | |
| `secrets` | `Optional[SecretRequest]` | Secrets to inject. Inherited from Environment. |
| `env_vars` | `Optional[Dict[str, str]]` | Environment variables. Inherited from Environment. |
| `resources` | `Optional[Resources]` | Compute resources (CPU, memory, GPU). Inherited from Environment. |
| `interruptible` | `bool` | |
| `image` | `Union[str, Image, Literal['auto'], None]` | Docker image for the environment. Inherited from Environment. |
| `type` | `Optional[str]` | App type identifier (e.g., `"streamlit"`, `"fastapi"`). When set, the platform may apply framework-specific defaults. |
| `port` | `int \| Port` | Port for the app server. Default `8080`. Ports 8012, 8022, 8112, 9090, and 9091 are reserved and cannot be used. Can also be a `Port` object for advanced configuration. |
| `args` | `*args` | Arguments passed to the app process. Can be a list of strings or a single string. Used for script-based apps (e.g., Streamlit's `["--server.port", "8080"]`). |
| `command` | `Optional[Union[List[str], str]]` | Full command to run in the container. Alternative to `args`; use when you need to override the container's entrypoint entirely. |
| `requires_auth` | `bool` | Whether the app endpoint requires authentication. Default `True`. Set to `False` for public endpoints. |
| `scaling` | `Scaling` | `Scaling` object controlling replicas and autoscaling behavior. Default is `Scaling()` (scale-to-zero, max 1 replica). |
| `domain` | `Domain \| None` | `Domain` object for custom domain configuration. |
| `links` | `List[Link]` | List of `Link` objects for connecting to other environments. |
| `include` | `List[str]` | List of additional file paths to bundle with the app (e.g., utility modules, config files, data files). |
| `parameters` | `List[Parameter]` | List of `Parameter` objects for app inputs. Use `RunOutput` to connect app parameters to task outputs, or `AppEndpoint` to reference other app endpoints. |
| `cluster_pool` | `str` | Cluster pool for scheduling. Default `"default"`. |
| `timeouts` | `Timeouts` | `Timeouts` object for startup/health check timeouts. |
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > add_dependency()** | Add a dependency to the environment. |
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > clone_with()** | |
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > container_args()** | |
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > container_cmd()** | |
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > get_port()** | |
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > on_shutdown()** | Decorator to define the shutdown function for the app environment. |
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > on_startup()** | Decorator to define the startup function for the app environment. |
| **Flyte SDK > Packages > flyte.app > AppEnvironment > Methods > server()** | Decorator to define the server function for the app environment. |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[List[Environment]],
description: Optional[str],
interruptible: Optional[bool],
kwargs: **kwargs,
) -> AppEnvironment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[List[Environment]]` | |
| `description` | `Optional[str]` | |
| `interruptible` | `Optional[bool]` | |
| `kwargs` | `**kwargs` | |
### container_args()
```python
def container_args(
serialize_context: SerializationContext,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
### container_cmd()
```python
def container_cmd(
serialize_context: SerializationContext,
parameter_overrides: list[Parameter] | None,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
| `parameter_overrides` | `list[Parameter] \| None` | |
### get_port()
```python
def get_port()
```
### on_shutdown()
```python
def on_shutdown(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the shutdown function for the app environment.
This function is called after the server function is called.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### on_startup()
```python
def on_startup(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the startup function for the app environment.
This function is called before the server function is called.
The decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### server()
```python
def server(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the server function for the app environment.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
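The three decorators above compose into the typical lifecycle of an app environment. The following is a minimal sketch (the function bodies are placeholders, not SDK-mandated names):

```python
import flyte.app

app_env = flyte.app.AppEnvironment(name="hello-app", port=8080)

@app_env.on_startup
def warm_up():
    # Runs before the server function, e.g. to load a model into memory.
    ...

@app_env.server
def serve():
    # Start your HTTP server here, listening on the configured port.
    ...

@app_env.on_shutdown
def clean_up():
    # Runs after the server function, e.g. to flush buffers.
    ...
```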
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/connectorenvironment ===
# ConnectorEnvironment
**Package:** `flyte.app`
## Parameters
```python
class ConnectorEnvironment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
image: Union[str, Image, Literal['auto'], None],
type: str,
port: int | flyte.app._types.Port,
args: *args,
command: Optional[Union[List[str], str]],
requires_auth: bool,
scaling: Scaling,
domain: Domain | None,
links: List[Link],
include: List[str],
parameters: List[Parameter],
cluster_pool: str,
timeouts: Timeouts,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `depends_on` | `List[Environment]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `description` | `Optional[str]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `resources` | `Optional[Resources]` | |
| `interruptible` | `bool` | |
| `image` | `Union[str, Image, Literal['auto'], None]` | |
| `type` | `str` | |
| `port` | `int \| flyte.app._types.Port` | |
| `args` | `*args` | |
| `command` | `Optional[Union[List[str], str]]` | |
| `requires_auth` | `bool` | |
| `scaling` | `Scaling` | |
| `domain` | `Domain \| None` | |
| `links` | `List[Link]` | |
| `include` | `List[str]` | |
| `parameters` | `List[Parameter]` | |
| `cluster_pool` | `str` | |
| `timeouts` | `Timeouts` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > add_dependency()** | Add a dependency to the environment. |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > clone_with()** | |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > container_args()** | |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > container_cmd()** | |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > get_port()** | |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > on_shutdown()** | Decorator to define the shutdown function for the app environment. |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > on_startup()** | Decorator to define the startup function for the app environment. |
| **Flyte SDK > Packages > flyte.app > ConnectorEnvironment > Methods > server()** | Decorator to define the server function for the app environment. |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[List[Environment]],
description: Optional[str],
interruptible: Optional[bool],
kwargs: **kwargs,
) -> AppEnvironment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[List[Environment]]` | |
| `description` | `Optional[str]` | |
| `interruptible` | `Optional[bool]` | |
| `kwargs` | `**kwargs` | |
### container_args()
```python
def container_args(
serialize_context: flyte.models.SerializationContext,
) -> typing.List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `flyte.models.SerializationContext` | |
### container_cmd()
```python
def container_cmd(
serialize_context: flyte.models.SerializationContext,
parameter_overrides: list[flyte.app._parameter.Parameter] | None,
) -> typing.List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `flyte.models.SerializationContext` | |
| `parameter_overrides` | `list[flyte.app._parameter.Parameter] \| None` | |
### get_port()
```python
def get_port()
```
### on_shutdown()
```python
def on_shutdown(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the shutdown function for the app environment.
This function is called after the server function is called.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### on_startup()
```python
def on_startup(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the startup function for the app environment.
This function is called before the server function is called.
The decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### server()
```python
def server(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the server function for the app environment.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/domain ===
# Domain
**Package:** `flyte.app`
Custom domain configuration for an app. The `subdomain` field sets the subdomain to use; if not set, the default subdomain will be used.
## Parameters
```python
class Domain(
subdomain: typing.Optional[str],
custom_domain: typing.Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `subdomain` | `typing.Optional[str]` | |
| `custom_domain` | `typing.Optional[str]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/link ===
# Link
**Package:** `flyte.app`
Custom links to add to the app.
## Parameters
```python
class Link(
path: str,
title: str,
is_relative: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `title` | `str` | |
| `is_relative` | `bool` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/parameter ===
# Parameter
**Package:** `flyte.app`
Parameter for application.
## Parameters
```python
class Parameter(
name: str,
value: ParameterTypes | _DelayedValue,
env_var: Optional[str],
download: bool,
mount: Optional[str],
ignore_patterns: list[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name of parameter. |
| `value` | `ParameterTypes \| _DelayedValue` | Value for parameter. |
| `env_var` | `Optional[str]` | Environment variable name under which to expose the value in the serving environment. |
| `download` | `bool` | When True, the parameter is automatically downloaded. This only works if the value refers to an item in an object store, e.g. `s3://...`. |
| `mount` | `Optional[str]` | If `value` is a directory, then the directory will be available at `mount`. If `value` is a file, then the file will be downloaded into the `mount` directory. |
| `ignore_patterns` | `list[str]` | If `value` is a directory, then this is a list of glob patterns to ignore. |
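As a sketch, a parameter that materializes a directory from object storage into the serving container might look like the following. The bucket path and mount point are hypothetical:

```python
import flyte.app

# Hypothetical parameter that downloads a model directory from object
# storage and mounts it at /models inside the serving container.
weights = flyte.app.Parameter(
    name="weights",
    value="s3://my-bucket/models/resnet",  # hypothetical object-store path
    download=True,
    mount="/models",
    ignore_patterns=["*.tmp"],  # skip temporary files when materializing
)
```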
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/port ===
# Port
**Package:** `flyte.app`
## Parameters
```python
class Port(
port: int,
name: typing.Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `port` | `int` | |
| `name` | `typing.Optional[str]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/runoutput ===
# RunOutput
**Package:** `flyte.app`
Use a run's output for app parameters.
This enables the declaration of an app parameter dependency on the output of
a run, given by a specific run name, or by a task name and version. If
`task_auto_version == 'latest'`, the latest version of the task is used. If
`task_auto_version == 'current'`, the version is derived from the callee app
or task context. To get the latest output of an ephemeral task run, leave
both `task_version` and `task_auto_version` set to `None` (the default).
Examples:
Get the output of a specific run:
```python
run_output = RunOutput(type="directory", run_name="my-run-123")
```
Get the latest output of an ephemeral task run:
```python
run_output = RunOutput(type="file", task_name="env.my_task")
```
Get the latest output of a deployed task run:
```python
run_output = RunOutput(type="file", task_name="env.my_task", task_auto_version="latest")
```
Get the output of a specific task run:
```python
run_output = RunOutput(type="file", task_name="env.my_task", task_version="xyz")
```
## Parameters
```python
class RunOutput(
type: typing.Literal['string', 'file', 'directory', 'app_endpoint'],
run_name: str | None,
task_name: str | None,
task_version: str | None,
task_auto_version: typing.Optional[typing.Literal['latest', 'current']],
getter: tuple[typing.Any, ...],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `type` | `typing.Literal['string', 'file', 'directory', 'app_endpoint']` | |
| `run_name` | `str \| None` | |
| `task_name` | `str \| None` | |
| `task_version` | `str \| None` | |
| `task_auto_version` | `typing.Optional[typing.Literal['latest', 'current']]` | |
| `getter` | `tuple[typing.Any, ...]` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app > RunOutput > Methods > check_type()** | |
| **Flyte SDK > Packages > flyte.app > RunOutput > Methods > get()** | |
| **Flyte SDK > Packages > flyte.app > RunOutput > Methods > materialize()** | |
### check_type()
```python
def check_type(
data: typing.Any,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `data` | `typing.Any` | |
### get()
```python
def get()
```
### materialize()
```python
def materialize()
```
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/scaling ===
# Scaling
**Package:** `flyte.app`
Controls replica count and autoscaling behavior for app environments.
Common scaling patterns:
- **Scale-to-zero** (default): `Scaling(replicas=(0, 1))` keeps no replicas when idle and scales to 1 on demand.
- **Always-on**: `Scaling(replicas=(1, 1))` keeps exactly 1 replica at all times.
- **Burstable**: `Scaling(replicas=(1, 5))` keeps 1 replica minimum and scales up to 5.
- **High-availability**: `Scaling(replicas=(2, 10))` keeps at least 2 replicas always running.
- **Fixed size**: `Scaling(replicas=3)` runs exactly 3 replicas.
## Parameters
```python
class Scaling(
replicas: typing.Union[int, typing.Tuple[int, int]],
metric: typing.Union[flyte.app._types.Scaling.Concurrency, flyte.app._types.Scaling.RequestRate, NoneType],
scaledown_after: int | datetime.timedelta | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `replicas` | `typing.Union[int, typing.Tuple[int, int]]` | Number of replicas. An `int` for fixed count, or a `(min, max)` tuple for autoscaling. Default `(0, 1)`. |
| `metric` | `typing.Union[flyte.app._types.Scaling.Concurrency, flyte.app._types.Scaling.RequestRate, NoneType]` | Autoscaling metric: `Scaling.Concurrency(val)` (scale when concurrent requests per replica exceed `val`) or `Scaling.RequestRate(val)` (scale when requests per second per replica exceed `val`). Default `None`. |
| `scaledown_after` | `int \| datetime.timedelta \| None` | Time to wait after the last request before scaling down. Seconds (`int`) or `timedelta`. Default `None` (platform default). |
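Combining the parameters above, a burstable configuration might look like the following sketch (the concurrency threshold and scale-down window are illustrative values):

```python
import flyte.app

# Burstable autoscaling: 1 to 5 replicas, scaling up when a replica
# handles more than 10 concurrent requests, and scaling down 5 minutes
# (300 seconds) after the last request.
scaling = flyte.app.Scaling(
    replicas=(1, 5),
    metric=flyte.app.Scaling.Concurrency(10),
    scaledown_after=300,
)
```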
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app > Scaling > Methods > get_replicas()** | |
### get_replicas()
```python
def get_replicas()
```
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/timeouts ===
# Timeouts
**Package:** `flyte.app`
Timeout configuration for the application.
## Parameters
```python
class Timeouts(
    request: int | datetime.timedelta | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `request` | `int \| datetime.timedelta \| None` | Timeout for requests to the application. Can be an `int` (seconds) or a `timedelta`. Must not exceed 1 hour. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app.extras ===
# flyte.app.extras
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIPassthroughAuthMiddleware** | FastAPI middleware that automatically sets Flyte auth metadata from request headers. |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment** | A pre-built FastAPI app environment for common Flyte webhook operations. |
## Subpages
- **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment**
- **Flyte SDK > Packages > flyte.app.extras > FastAPIPassthroughAuthMiddleware**
- **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app.extras/fastapiappenvironment ===
# FastAPIAppEnvironment
**Package:** `flyte.app.extras`
## Parameters
```python
class FastAPIAppEnvironment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
image: Union[str, Image, Literal['auto'], None],
port: int | Port,
args: *args,
command: Optional[Union[List[str], str]],
requires_auth: bool,
scaling: Scaling,
domain: Domain | None,
links: List[Link],
include: List[str],
parameters: List[Parameter],
cluster_pool: str,
timeouts: Timeouts,
type: str,
app: fastapi.FastAPI,
uvicorn_config: uvicorn.Config | None,
_caller_frame: inspect.FrameInfo | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `depends_on` | `List[Environment]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `description` | `Optional[str]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `resources` | `Optional[Resources]` | |
| `interruptible` | `bool` | |
| `image` | `Union[str, Image, Literal['auto'], None]` | |
| `port` | `int \| Port` | |
| `args` | `*args` | |
| `command` | `Optional[Union[List[str], str]]` | |
| `requires_auth` | `bool` | |
| `scaling` | `Scaling` | |
| `domain` | `Domain \| None` | |
| `links` | `List[Link]` | |
| `include` | `List[str]` | |
| `parameters` | `List[Parameter]` | |
| `cluster_pool` | `str` | |
| `timeouts` | `Timeouts` | |
| `type` | `str` | |
| `app` | `fastapi.FastAPI` | |
| `uvicorn_config` | `uvicorn.Config \| None` | |
| `_caller_frame` | `inspect.FrameInfo \| None` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > add_dependency()** | Add a dependency to the environment. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > clone_with()** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > container_args()** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > container_cmd()** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > container_command()** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > get_port()** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > on_shutdown()** | Decorator to define the shutdown function for the app environment. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > on_startup()** | Decorator to define the startup function for the app environment. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > Methods > server()** | Decorator to define the server function for the app environment. |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[List[Environment]],
description: Optional[str],
interruptible: Optional[bool],
kwargs: **kwargs,
) -> AppEnvironment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[List[Environment]]` | |
| `description` | `Optional[str]` | |
| `interruptible` | `Optional[bool]` | |
| `kwargs` | `**kwargs` | |
### container_args()
```python
def container_args(
serialize_context: SerializationContext,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
### container_cmd()
```python
def container_cmd(
serialize_context: SerializationContext,
parameter_overrides: list[Parameter] | None,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
| `parameter_overrides` | `list[Parameter] \| None` | |
### container_command()
```python
def container_command(
serialization_context: SerializationContext,
) -> list[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialization_context` | `SerializationContext` | |
### get_port()
```python
def get_port()
```
### on_shutdown()
```python
def on_shutdown(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the shutdown function for the app environment.
This function is called after the server function is called.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### on_startup()
```python
def on_startup(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the startup function for the app environment.
This function is called before the server function is called.
The decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### server()
```python
def server(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the server function for the app environment.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
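Together, `on_startup()`, `server()`, and `on_shutdown()` define the app lifecycle: startup runs before the server function and shutdown after it. The following is a minimal pure-Python sketch of that decorator pattern (illustrative only, not the SDK's implementation; `MiniAppEnvironment` and its hooks are hypothetical stand-ins):

```python
# Illustrative stand-in for an app environment's lifecycle decorators.
# This is NOT the Flyte SDK implementation; it only sketches the pattern.
class MiniAppEnvironment:
    def __init__(self, name: str):
        self.name = name
        self._startup = self._server = self._shutdown = None

    def on_startup(self, fn):
        self._startup = fn   # registered to run before the server function
        return fn

    def server(self, fn):
        self._server = fn    # the main server entry point
        return fn

    def on_shutdown(self, fn):
        self._shutdown = fn  # registered to run after the server function
        return fn

    def run(self):
        # Lifecycle order: startup -> server -> shutdown
        events = []
        for hook in (self._startup, self._server, self._shutdown):
            if hook is not None:
                events.append(hook())
        return events


env = MiniAppEnvironment("demo")

@env.on_startup
def start():
    return "started"

@env.server
def serve():
    return "serving"

@env.on_shutdown
def stop():
    return "stopped"
```

Calling `env.run()` on this sketch executes the three registered hooks in lifecycle order.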
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app.extras/fastapipassthroughauthmiddleware ===
# FastAPIPassthroughAuthMiddleware
**Package:** `flyte.app.extras`
FastAPI middleware that automatically sets Flyte auth metadata from request headers.
This middleware extracts authentication headers from incoming HTTP requests and
sets them in the Flyte context using the auth_metadata() context manager. This
eliminates the need to manually wrap endpoint handlers with auth_metadata().
The middleware is highly configurable:
- Custom header extractors can be provided
- Specific paths can be excluded from auth requirements
- Auth can be optional or required
Thread Safety:
This middleware is async-safe and properly isolates auth metadata per request
using Python's contextvars. Multiple concurrent requests with different
authentication will not interfere with each other.
## Parameters
```python
class FastAPIPassthroughAuthMiddleware(
app,
header_extractors: list[HeaderExtractor] | None,
excluded_paths: set[str] | None,
)
```
Initialize the Flyte authentication middleware.
| Parameter | Type | Description |
|-|-|-|
| `app` | | The FastAPI application (this is a mandatory framework parameter) |
| `header_extractors` | `list[HeaderExtractor] \| None` | List of functions to extract headers from requests |
| `excluded_paths` | `set[str] \| None` | Set of URL paths that bypass auth extraction |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app.extras > FastAPIPassthroughAuthMiddleware > Methods > dispatch()** | Process each request, extracting auth headers and setting Flyte context. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIPassthroughAuthMiddleware > Methods > extract_authorization_header()** | Extract the Authorization header from the request. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIPassthroughAuthMiddleware > Methods > extract_cookie_header()** | Extract the Cookie header from the request. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIPassthroughAuthMiddleware > Methods > extract_custom_header()** | Create a header extractor for a custom header name. |
### dispatch()
```python
def dispatch(
request: 'Request',
call_next,
) -> 'Response'
```
Process each request, extracting auth headers and setting Flyte context.
| Parameter | Type | Description |
|-|-|-|
| `request` | `'Request'` | The incoming HTTP request |
| `call_next` | | The next middleware or route handler to call |
**Returns:** The HTTP response from the handler
### extract_authorization_header()
```python
def extract_authorization_header(
request: 'Request',
) -> tuple[str, str] | None
```
Extract the Authorization header from the request.
| Parameter | Type | Description |
|-|-|-|
| `request` | `'Request'` | The FastAPI/Starlette request object |
**Returns:** Tuple of ("authorization", header_value) if present, None otherwise
### extract_cookie_header()
```python
def extract_cookie_header(
request: 'Request',
) -> tuple[str, str] | None
```
Extract the Cookie header from the request.
| Parameter | Type | Description |
|-|-|-|
| `request` | `'Request'` | The FastAPI/Starlette request object |
**Returns:** Tuple of ("cookie", header_value) if present, None otherwise
### extract_custom_header()
```python
def extract_custom_header(
header_name: str,
) -> HeaderExtractor
```
Create a header extractor for a custom header name.
Example:
```python
# Create an extractor for the X-API-Key header
api_key_extractor = extract_custom_header("x-api-key")

app.add_middleware(
    FastAPIPassthroughAuthMiddleware,
    header_extractors=[api_key_extractor],
)
```
| Parameter | Type | Description |
|-|-|-|
| `header_name` | `str` | The name of the header to extract (case-insensitive) |
**Returns:** A header extractor function that extracts the specified header
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app.extras/flytewebhookappenvironment ===
# FlyteWebhookAppEnvironment
**Package:** `flyte.app.extras`
A pre-built FastAPI app environment for common Flyte webhook operations.
This environment provides a ready-to-use FastAPI application with endpoints for:
- Running tasks in a specific domain/project/version
- Getting run I/O and metadata
- Aborting runs
- Getting task metadata
- Building images
- Activating/deactivating apps (except itself)
- Getting app status
- Calling other app endpoints
- Activating/deactivating triggers
- Prefetching HuggingFace models (run, status, I/O, abort)
All endpoints use FastAPIPassthroughAuthMiddleware for authentication.
Example:
Basic usage (all endpoints enabled):
```python
import flyte
from flyte.app.extras import FlyteWebhookAppEnvironment

webhook_env = FlyteWebhookAppEnvironment(
    name="my-webhook",
    image=flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn"),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

# Deploy the webhook
flyte.serve(webhook_env)
```
With endpoint group filtering:
```python
from flyte.app.extras import FlyteWebhookAppEnvironment

# Only enable core, task, and run endpoint groups
webhook_env = FlyteWebhookAppEnvironment(
    name="task-runner-webhook",
    endpoint_groups=["core", "task", "run"],
)
```
With individual endpoint filtering:
```python
from flyte.app.extras import FlyteWebhookAppEnvironment

# Only enable specific endpoints
webhook_env = FlyteWebhookAppEnvironment(
    name="minimal-webhook",
    endpoints=["health", "run_task", "get_run"],
)
```
Combining endpoint groups and individual endpoints:
```python
from flyte.app.extras import FlyteWebhookAppEnvironment

# Enable the core group plus specific additional endpoints
webhook_env = FlyteWebhookAppEnvironment(
    name="custom-webhook",
    endpoint_groups=["core"],
    endpoints=["run_task", "get_run"],
)
```
With task allow-listing:
```python
from flyte.app.extras import FlyteWebhookAppEnvironment

# Only allow specific tasks
webhook_env = FlyteWebhookAppEnvironment(
    name="restricted-webhook",
    endpoint_groups=["core", "task", "run"],
    task_allowlist=["production/my-project/allowed-task", "my-other-task"],
)
```
With app allow-listing:
```python
from flyte.app.extras import FlyteWebhookAppEnvironment

# Only allow specific apps
webhook_env = FlyteWebhookAppEnvironment(
    name="app-manager-webhook",
    endpoint_groups=["core", "app"],
    app_allowlist=["my-app", "another-app"],
)
```
With trigger allow-listing:
```python
from flyte.app.extras import FlyteWebhookAppEnvironment

# Only allow specific triggers
webhook_env = FlyteWebhookAppEnvironment(
    name="trigger-manager-webhook",
    endpoint_groups=["core", "trigger"],
    trigger_allowlist=["my-task/my-trigger", "another-trigger"],
)
```
## Parameters
```python
class FlyteWebhookAppEnvironment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
port: int | Port,
args: *args,
command: Optional[Union[List[str], str]],
requires_auth: bool,
scaling: Scaling,
domain: Domain | None,
links: List[Link],
include: List[str],
parameters: List[Parameter],
cluster_pool: str,
timeouts: Timeouts,
image: flyte.Image,
type: str,
uvicorn_config: 'uvicorn.Config | None',
_caller_frame: inspect.FrameInfo | None,
title: str | None,
endpoint_groups: list[WebhookEndpointGroup] | tuple[WebhookEndpointGroup, ...] | None,
endpoints: list[WebhookEndpoint] | tuple[WebhookEndpoint, ...] | None,
task_allowlist: list[str] | None,
app_allowlist: list[str] | None,
trigger_allowlist: list[str] | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name of the webhook app environment |
| `depends_on` | `List[Environment]` | Environment dependencies |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `description` | `Optional[str]` | Description for the FastAPI app (optional) |
| `secrets` | `Optional[SecretRequest]` | Secrets to inject into the environment |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `resources` | `Optional[Resources]` | Resources to allocate for the environment |
| `interruptible` | `bool` | |
| `port` | `int \| Port` | |
| `args` | `*args` | |
| `command` | `Optional[Union[List[str], str]]` | |
| `requires_auth` | `bool` | Whether the app requires authentication (default: True) |
| `scaling` | `Scaling` | Scaling configuration for the app environment |
| `domain` | `Domain \| None` | |
| `links` | `List[Link]` | |
| `include` | `List[str]` | |
| `parameters` | `List[Parameter]` | |
| `cluster_pool` | `str` | |
| `timeouts` | `Timeouts` | |
| `image` | `flyte.Image` | Docker image to use for the environment |
| `type` | `str` | |
| `uvicorn_config` | `'uvicorn.Config \| None'` | |
| `_caller_frame` | `inspect.FrameInfo \| None` | |
| `title` | `str \| None` | Title for the FastAPI app (optional) |
| `endpoint_groups` | `list[WebhookEndpointGroup] \| tuple[WebhookEndpointGroup, ...] \| None` | List of endpoint groups to enable. If None (and endpoints is None), all endpoints are enabled. Available groups (see WebhookEndpointGroup type): - "all": All available endpoints - "core": Health check and user info ("health", "me") - "task": Task operations ("run_task", "get_task") - "run": Run operations ("get_run", "get_run_io", "abort_run") - "app": App operations ("get_app", "activate_app", "deactivate_app", "call_app") - "trigger": Trigger operations ("activate_trigger", "deactivate_trigger") - "build": Image build operations ("build_image") - "prefetch": HuggingFace prefetch operations ("prefetch_hf_model", "get_prefetch_hf_model", "get_prefetch_hf_model_io", "abort_prefetch_hf_model") |
| `endpoints` | `list[WebhookEndpoint] \| tuple[WebhookEndpoint, ...] \| None` | List of individual endpoints to enable. Can be used alone or combined with endpoint_groups. Available endpoints (see WebhookEndpoint type): - "health": Health check endpoint - "me": Get current user info - "run_task": Run a task - "get_task": Get task metadata - "get_run": Get run status - "get_run_io": Get run inputs/outputs - "abort_run": Abort a run - "get_app": Get app status - "activate_app": Activate an app - "deactivate_app": Deactivate an app - "call_app": Call another app's endpoint - "activate_trigger": Activate a trigger - "deactivate_trigger": Deactivate a trigger - "build_image": Build a container image - "prefetch_hf_model": Prefetch a HuggingFace model - "get_prefetch_hf_model": Get prefetch run status - "get_prefetch_hf_model_io": Get prefetch run I/O - "abort_prefetch_hf_model": Abort a prefetch run |
| `task_allowlist` | `list[str] \| None` | List of allowed task identifiers. When set, only tasks matching the allowlist can be accessed via task endpoints. Supports formats: - "domain/project/name" for exact match - "project/name" for project/name match (any domain) - "name" for name-only match (any domain/project) |
| `app_allowlist` | `list[str] \| None` | List of allowed app names. When set, only apps matching the allowlist can be accessed via app endpoints. |
| `trigger_allowlist` | `list[str] \| None` | List of allowed trigger identifiers. When set, only triggers matching the allowlist can be accessed via trigger endpoints. Supports formats: - "task_name/trigger_name" for exact match - "trigger_name" for name-only match (any task) |
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > add_dependency()** | Add a dependency to the environment. |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > clone_with()** | |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > container_args()** | |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > container_cmd()** | |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > container_command()** | |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > get_port()** | |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > on_shutdown()** | Decorator to define the shutdown function for the app environment. |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > on_startup()** | Decorator to define the startup function for the app environment. |
| **Flyte SDK > Packages > flyte.app.extras > FlyteWebhookAppEnvironment > Methods > server()** | Decorator to define the server function for the app environment. |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[List[Environment]],
description: Optional[str],
interruptible: Optional[bool],
kwargs: **kwargs,
) -> AppEnvironment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[List[Environment]]` | |
| `description` | `Optional[str]` | |
| `interruptible` | `Optional[bool]` | |
| `kwargs` | `**kwargs` | |
### container_args()
```python
def container_args(
serialize_context: SerializationContext,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
### container_cmd()
```python
def container_cmd(
serialize_context: SerializationContext,
parameter_overrides: list[Parameter] | None,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
| `parameter_overrides` | `list[Parameter] \| None` | |
### container_command()
```python
def container_command(
serialization_context: SerializationContext,
) -> list[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialization_context` | `SerializationContext` | |
### get_port()
```python
def get_port()
```
### on_shutdown()
```python
def on_shutdown(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the shutdown function for the app environment.
This function is called after the server function is called.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### on_startup()
```python
def on_startup(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the startup function for the app environment.
This function is called before the server function is called.
The decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### server()
```python
def server(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the server function for the app environment.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.config ===
# flyte.config
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.config > Config** | This is the parent configuration object and holds all the underlying configuration object types. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.config > Methods > auto()** | Automatically constructs the Config Object. |
| **Flyte SDK > Packages > flyte.config > Methods > set_if_exists()** | Given a dict `d`, sets the key `k` to the config value `v`, if that value is set. |
## Methods
#### auto()
```python
def auto(
config_file: typing.Union[str, pathlib.Path, ConfigFile, None],
) -> Config
```
Automatically constructs the Config object. The order of precedence is as follows:
1. If specified, read the config from the provided file path.
2. If not specified, search for a config file in the default locations:
   a. `./config.yaml`, if it exists (current working directory)
   b. `./.flyte/config.yaml`, if it exists (current working directory)
   c. `<git_root>/.flyte/config.yaml`, if it exists
   d. the path in the `UCTL_CONFIG` environment variable
   e. the path in the `FLYTECTL_CONFIG` environment variable
   f. `~/.union/config.yaml`, if it exists
   g. `~/.flyte/config.yaml`, if it exists
3. If a value is not found in the config file, the default value is used.
4. For any value, environment variables that match the config variable names override the value from the config file.
| Parameter | Type | Description |
|-|-|-|
| `config_file` | `typing.Union[str, pathlib.Path, ConfigFile, None]` | file path to read the config from, if not specified default locations are searched |
**Returns:** Config
#### set_if_exists()
```python
def set_if_exists(
d: dict,
k: str,
val: typing.Any,
) -> dict
```
Given a dict `d`, sets the key `k` to the config value `val`, if that value is set,
and returns the updated dictionary.
| Parameter | Type | Description |
|-|-|-|
| `d` | `dict` | |
| `k` | `str` | |
| `val` | `typing.Any` | |
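The helper's behavior can be illustrated with a pure-Python equivalent (a sketch, not the SDK source, and it assumes "set" means truthy):

```python
def set_if_exists(d: dict, k: str, val) -> dict:
    # Only write the key when the config value is actually set (truthy);
    # otherwise leave the dict untouched.
    if val:
        d[k] = val
    return d

cfg = {}
set_if_exists(cfg, "endpoint", "dns://example.org")
set_if_exists(cfg, "insecure", None)  # not set: key is skipped
```

After these calls, `cfg` contains only the `endpoint` key.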
## Subpages
- **Flyte SDK > Packages > flyte.config > Config**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.config/config ===
# Config
**Package:** `flyte.config`
This is the parent configuration object and holds all the underlying configuration object types. An instance of
this object holds all the config necessary for:
1. An interactive session with the Flyte backend
2. Serialization (only some parts are required; the Platform config, for example, is not)
3. The runtime of a task
## Parameters
```python
class Config(
platform: PlatformConfig,
task: TaskConfig,
image: ImageConfig,
local: LocalConfig,
source: pathlib.Path | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `platform` | `PlatformConfig` | |
| `task` | `TaskConfig` | |
| `image` | `ImageConfig` | |
| `local` | `LocalConfig` | |
| `source` | `pathlib.Path \| None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.config > Config > Methods > auto()** | Automatically constructs the Config Object. |
| **Flyte SDK > Packages > flyte.config > Config > Methods > with_params()** | |
### auto()
```python
def auto(
config_file: typing.Union[str, pathlib.Path, ConfigFile, None],
) -> 'Config'
```
Automatically constructs the Config object. The order of precedence is as follows:
1. First, look for environment variables that match the config variables in the `FLYTE_CONFIG` format.
2. If not found in the environment, values are read from the config file.
3. If not found in the file, the default values are used.
| Parameter | Type | Description |
|-|-|-|
| `config_file` | `typing.Union[str, pathlib.Path, ConfigFile, None]` | file path to read the config from, if not specified default locations are searched |
**Returns:** Config
### with_params()
```python
def with_params(
platform: PlatformConfig | None,
task: TaskConfig | None,
image: ImageConfig | None,
) -> 'Config'
```
| Parameter | Type | Description |
|-|-|-|
| `platform` | `PlatformConfig \| None` | |
| `task` | `TaskConfig \| None` | |
| `image` | `ImageConfig \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors ===
# flyte.connectors
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.connectors > AsyncConnector** | This is the base class for all async connectors, and it defines the interface that all connectors must implement. |
| **Flyte SDK > Packages > flyte.connectors > AsyncConnectorExecutorMixin** | This mixin class is used to run the connector task locally, and it's only used for local execution. |
| **Flyte SDK > Packages > flyte.connectors > ConnectorRegistry** | This is the registry for all connectors. |
| **Flyte SDK > Packages > flyte.connectors > ConnectorService** | |
| **Flyte SDK > Packages > flyte.connectors > Resource** | This is the output resource of the job. |
| **Flyte SDK > Packages > flyte.connectors > ResourceMeta** | This is the metadata for the job. |
## Subpages
- **Flyte SDK > Packages > flyte.connectors > AsyncConnector**
- **Flyte SDK > Packages > flyte.connectors > AsyncConnectorExecutorMixin**
- **Flyte SDK > Packages > flyte.connectors > ConnectorRegistry**
- **Flyte SDK > Packages > flyte.connectors > ConnectorService**
- **Flyte SDK > Packages > flyte.connectors > Resource**
- **Flyte SDK > Packages > flyte.connectors > ResourceMeta**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors/asyncconnector ===
# AsyncConnector
**Package:** `flyte.connectors`
This is the base class for all async connectors, and it defines the interface that all connectors must implement.
The connector service is responsible for invoking connectors.
The executor will communicate with the connector service to create tasks, get the status of tasks, and delete tasks.
All the connectors should be registered in the ConnectorRegistry.
Connector Service will look up the connector based on the task type and version.
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.connectors > AsyncConnector > Methods > create()** | Return a resource meta that can be used to get the status of the task. |
| **Flyte SDK > Packages > flyte.connectors > AsyncConnector > Methods > delete()** | Delete the task. |
| **Flyte SDK > Packages > flyte.connectors > AsyncConnector > Methods > get()** | Return the status of the task, and return the outputs in some cases. |
| **Flyte SDK > Packages > flyte.connectors > AsyncConnector > Methods > get_logs()** | Return the logs for the task. |
| **Flyte SDK > Packages > flyte.connectors > AsyncConnector > Methods > get_metrics()** | Return the metrics for the task. |
### create()
```python
def create(
task_template: flyteidl2.core.tasks_pb2.TaskTemplate,
output_prefix: str,
inputs: typing.Optional[typing.Dict[str, typing.Any]],
task_execution_metadata: typing.Optional[flyteidl2.connector.connector_pb2.TaskExecutionMetadata],
kwargs,
) -> flyte.connectors._connector.ResourceMeta
```
Return a resource meta that can be used to get the status of the task.
| Parameter | Type | Description |
|-|-|-|
| `task_template` | `flyteidl2.core.tasks_pb2.TaskTemplate` | |
| `output_prefix` | `str` | |
| `inputs` | `typing.Optional[typing.Dict[str, typing.Any]]` | |
| `task_execution_metadata` | `typing.Optional[flyteidl2.connector.connector_pb2.TaskExecutionMetadata]` | |
| `kwargs` | `**kwargs` | |
### delete()
```python
def delete(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
)
```
Delete the task. This call should be idempotent. It should raise an error if it fails to delete the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
### get()
```python
def get(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
) -> flyte.connectors._connector.Resource
```
Return the status of the task and, in some cases, the outputs. For example, a BigQuery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
and the propeller writes the structured dataset to the blob store.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
### get_logs()
```python
def get_logs(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskLogsResponse
```
Return the logs for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
### get_metrics()
```python
def get_metrics(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskMetricsResponse
```
Return the metrics for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors/asyncconnectorexecutormixin ===
# AsyncConnectorExecutorMixin
**Package:** `flyte.connectors`
This mixin class is used to run the connector task locally; it is only used for local execution.
A task should inherit from this class if it can be run by a connector.
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.connectors > AsyncConnectorExecutorMixin > Methods > execute()** | |
### execute()
```python
def execute(
kwargs,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors/connectorregistry ===
# ConnectorRegistry
**Package:** `flyte.connectors`
This is the registry for all connectors.
The connector service will look up the connector registry based on the task type and version.
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.connectors > ConnectorRegistry > Methods > get_connector()** | |
| **Flyte SDK > Packages > flyte.connectors > ConnectorRegistry > Methods > register()** | |
### get_connector()
```python
def get_connector(
task_type_name: str,
task_type_version: int,
) -> flyte.connectors._connector.AsyncConnector
```
| Parameter | Type | Description |
|-|-|-|
| `task_type_name` | `str` | |
| `task_type_version` | `int` | |
### register()
```python
def register(
connector: flyte.connectors._connector.AsyncConnector,
override: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `connector` | `flyte.connectors._connector.AsyncConnector` | |
| `override` | `bool` | |
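The registry looks connectors up by task type name and version, and `override` controls whether an existing registration may be replaced. A minimal sketch of that pattern (`MiniConnectorRegistry` is illustrative only; real connectors subclass `AsyncConnector`):

```python
class MiniConnectorRegistry:
    # Connectors are keyed by (task_type_name, task_type_version),
    # mirroring the documented ConnectorRegistry lookup.
    _connectors: dict[tuple[str, int], object] = {}

    @classmethod
    def register(cls, name: str, version: int, connector, override: bool = False):
        key = (name, version)
        if key in cls._connectors and not override:
            raise ValueError(f"connector already registered for {key}")
        cls._connectors[key] = connector

    @classmethod
    def get_connector(cls, name: str, version: int):
        try:
            return cls._connectors[(name, version)]
        except KeyError:
            raise ValueError(f"no connector for task type {name!r} v{version}")
```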
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors/connectorservice ===
# ConnectorService
**Package:** `flyte.connectors`
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.connectors > ConnectorService > Methods > run()** | |
### run()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await ConnectorService.run.aio()`.
```python
def run(
cls,
port: int,
prometheus_port: int,
worker: int,
timeout: int | None,
modules: typing.Optional[typing.List[str]],
)
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `port` | `int` | |
| `prometheus_port` | `int` | |
| `worker` | `int` | |
| `timeout` | `int \| None` | |
| `modules` | `typing.Optional[typing.List[str]]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors/resource ===
# Resource
**Package:** `flyte.connectors`
This is the output resource of the job.
Attributes:
- `phase` (`TaskExecution.Phase`): The phase of the job.
- `message` (`Optional[str]`): The return message from the job.
- `log_links` (`Optional[List[TaskLog]]`): The log links of the job, for example, the link to the BigQuery console.
- `outputs` (`Optional[Union[LiteralMap, typing.Dict[str, Any]]]`): The outputs of the job. If Python native types are returned, the agent converts them to Flyte literals.
- `custom_info` (`Optional[typing.Dict[str, Any]]`): Custom info about the job, for example, the job config.
## Parameters
```python
class Resource(
phase: google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper,
message: typing.Optional[str],
log_links: typing.Optional[typing.List[flyteidl2.core.execution_pb2.TaskLog]],
outputs: typing.Optional[typing.Dict[str, typing.Any]],
custom_info: typing.Optional[typing.Dict[str, typing.Any]],
)
```
| Parameter | Type | Description |
|-|-|-|
| `phase` | `google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper` | |
| `message` | `typing.Optional[str]` | |
| `log_links` | `typing.Optional[typing.List[flyteidl2.core.execution_pb2.TaskLog]]` | |
| `outputs` | `typing.Optional[typing.Dict[str, typing.Any]]` | |
| `custom_info` | `typing.Optional[typing.Dict[str, typing.Any]]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors/resourcemeta ===
# ResourceMeta
**Package:** `flyte.connectors`
This is the metadata for the job. For example, the id of the job.
## Parameters
```python
class ResourceMeta()
```
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.connectors > ResourceMeta > Methods > decode()** | Decode the resource meta from bytes. |
| **Flyte SDK > Packages > flyte.connectors > ResourceMeta > Methods > encode()** | Encode the resource meta to bytes. |
### decode()
```python
def decode(
data: bytes,
) -> ResourceMeta
```
Decode the resource meta from bytes.
| Parameter | Type | Description |
|-|-|-|
| `data` | `bytes` | |
### encode()
```python
def encode()
```
Encode the resource meta to bytes.
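`encode()` and `decode()` form a round-trip: a connector serializes its job metadata to bytes and restores it later. A sketch of that contract using a dataclass and JSON (an illustration of the pattern; `JobMeta` is hypothetical and the SDK's actual wire format is an implementation detail):

```python
import json
from dataclasses import dataclass, asdict

@dataclass
class JobMeta:
    # Hypothetical connector metadata, e.g. the remote job id.
    job_id: str

    def encode(self) -> bytes:
        # Serialize the metadata to bytes.
        return json.dumps(asdict(self)).encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> "JobMeta":
        # Restore the metadata from bytes.
        return cls(**json.loads(data.decode("utf-8")))
```

A round-trip through `encode()` and `decode()` yields an equal object.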
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.connectors.utils ===
# flyte.connectors.utils
## Directory
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.connectors.utils > Methods > convert_to_flyte_phase()** | Convert the state from the connector to the phase in flyte. |
| **Flyte SDK > Packages > flyte.connectors.utils > Methods > is_terminal_phase()** | Return true if the phase is terminal. |
| **Flyte SDK > Packages > flyte.connectors.utils > Methods > print_metadata()** | |
## Methods
#### convert_to_flyte_phase()
```python
def convert_to_flyte_phase(
state: str,
) -> google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper
```
Convert the state from the connector to the phase in flyte.
| Parameter | Type | Description |
|-|-|-|
| `state` | `str` | |
#### is_terminal_phase()
```python
def is_terminal_phase(
phase: google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper,
) -> bool
```
Return true if the phase is terminal.
| Parameter | Type | Description |
|-|-|-|
| `phase` | `google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper` | |
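A connector typically maps backend-specific job states onto Flyte phases and then checks for terminal ones. A simplified sketch of both helpers using plain strings (the state names and phase strings here are assumptions for illustration; the real functions work with protobuf phase enums):

```python
# Hypothetical state-to-phase mapping; the real helpers use protobuf enums.
_STATE_TO_PHASE = {
    "pending": "QUEUED",
    "running": "RUNNING",
    "done": "SUCCEEDED",
    "error": "FAILED",
    "cancelled": "ABORTED",
}
_TERMINAL = {"SUCCEEDED", "FAILED", "ABORTED"}

def convert_to_phase(state: str) -> str:
    # Translate a backend job state into a Flyte phase.
    return _STATE_TO_PHASE[state.lower()]

def is_terminal(phase: str) -> bool:
    # A terminal phase means polling can stop.
    return phase in _TERMINAL
```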
#### print_metadata()
```python
def print_metadata()
```
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.durable ===
# flyte.durable
Flyte durable utilities.
This module provides deterministic, crash-resilient replacements for time-related functions.
Usage of `time.time()`, `time.sleep()` or `asyncio.sleep()` introduces non-determinism.
The utilities here persist state across crashes and restarts, making workflows durable.
- `sleep` - a durable replacement for `time.sleep` / `asyncio.sleep`
- `time` - a durable replacement for `time.time`
- `now` - a durable replacement for `datetime.now`
## Directory
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.durable > Methods > now()** | Returns the current time for every unique invocation of durable_time. |
| **Flyte SDK > Packages > flyte.durable > Methods > sleep()** | durable_sleep enables the process to sleep for `seconds` seconds even if the process recovers from a crash. |
| **Flyte SDK > Packages > flyte.durable > Methods > time()** | Returns the current time for every unique invocation of durable_time. |
## Methods
#### now()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method name itself, e.g.,
> `result = await now.aio()`.
```python
def now()
```
Returns the current time for every unique invocation. If the same invocation is encountered again (for example, on replay after a crash), the previously returned time is returned, ensuring determinism.
Similar to `datetime.now()`, just durable.
Returns: datetime.datetime
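The determinism guarantee described above can be pictured as a journal keyed by invocation: the first run records the wall-clock value, and any replay returns the recorded value instead of reading the clock again. Below is a toy, in-memory model of that behavior; it is not the SDK's implementation, which persists the journal across crashes and restarts.

```python
# Toy model of durable time: the first call for an invocation records
# the clock value; replays return the recorded value, so the workflow
# sees the same time even after a restart. In the real system the
# journal is persisted, not an in-memory dict.
from datetime import datetime, timezone

_journal: dict[str, datetime] = {}

def durable_now(invocation_id: str) -> datetime:
    if invocation_id not in _journal:
        _journal[invocation_id] = datetime.now(timezone.utc)
    return _journal[invocation_id]

first = durable_now("step-1")
replayed = durable_now("step-1")  # simulated replay after a crash
```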
#### sleep()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method name itself, e.g.,
> `result = await sleep.aio()`.
```python
def sleep(
seconds: float,
)
```
`sleep` enables the process to sleep for `seconds` seconds even if the process recovers from a crash. This method can be invoked multiple times. If the process crashes, the invocation behaves as if the process had been sleeping since the first time the method was invoked.
Examples:
```python
import flyte
import flyte.durable

env = flyte.TaskEnvironment("env")

@env.task
async def main():
    # Do something
    my_work()

    # Now we need to sleep for 1 hour before proceeding. Even if the
    # process crashes, the run will resume and only sleep for 1 hour in
    # aggregate. If rescheduling takes longer, the call simply returns
    # immediately.
    await flyte.durable.sleep.aio(3600)

    # The thing to be done after 1 hour
    my_work()
```
| Parameter | Type | Description |
|-|-|-|
| `seconds` | `float` | Time to sleep, in seconds |
#### time()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method name itself, e.g.,
> `result = await time.aio()`.
```python
def time()
```
Returns the current time for every unique invocation. If the same invocation is encountered again, the previously returned time is returned, ensuring determinism.
Similar to `time.time()`, just durable.
Returns: float
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors ===
# flyte.errors
Exceptions raised by Union.
These errors are raised when the underlying task execution fails, either because of a user error, a system error, or an unknown error.
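Many of the runtime errors below share a common shape: a machine-readable code, a message, and a classification of the failure as a user, system, or unknown error. That classification is what lets calling code decide whether a failure points at the task's own logic or at the platform. The following standalone sketch illustrates the pattern; `SketchRuntimeError` is a stand-in for illustration, not the SDK's `BaseRuntimeError`.

```python
# Standalone sketch of the error taxonomy: a base error carrying a
# code and a "kind" that classifies the failure. SketchRuntimeError
# is illustrative, not the SDK's BaseRuntimeError.
from typing import Literal

class SketchRuntimeError(Exception):
    def __init__(self, code: str, kind: Literal["system", "unknown", "user"], message: str):
        super().__init__(message)
        self.code = code
        self.kind = kind

def points_at_platform(err: SketchRuntimeError) -> bool:
    # User errors implicate the task code; system/unknown errors
    # suggest a platform-side problem.
    return err.kind in ("system", "unknown")

oom = SketchRuntimeError("OOM", "user", "container exceeded its memory limit")
```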
## Directory
### Errors
| Exception | Description |
|-|-|
| **Flyte SDK > Packages > flyte.errors > ActionAbortedError** | This error is raised when an action was aborted externally. |
| **Flyte SDK > Packages > flyte.errors > ActionNotFoundError** | This error is raised when the user tries to access an action that does not exist. |
| **Flyte SDK > Packages > flyte.errors > BaseRuntimeError** | Base class for all Union runtime errors. |
| **Flyte SDK > Packages > flyte.errors > CodeBundleError** | This error is raised when the code bundle cannot be created, for example when no files are found to bundle. |
| **Flyte SDK > Packages > flyte.errors > CustomError** | This error is raised when the user raises a custom error. |
| **Flyte SDK > Packages > flyte.errors > DeploymentError** | This error is raised when the deployment of a task fails, or some preconditions for deployment are not met. |
| **Flyte SDK > Packages > flyte.errors > ImageBuildError** | This error is raised when the image build fails. |
| **Flyte SDK > Packages > flyte.errors > ImagePullBackOffError** | This error is raised when the image cannot be pulled. |
| **Flyte SDK > Packages > flyte.errors > InitializationError** | This error is raised when the Union system is accessed without being initialized. |
| **Flyte SDK > Packages > flyte.errors > InlineIOMaxBytesBreached** | This error is raised when the inline IO max bytes limit is breached. |
| **Flyte SDK > Packages > flyte.errors > InvalidImageNameError** | This error is raised when the image name is invalid. |
| **Flyte SDK > Packages > flyte.errors > InvalidPackageError** | Raised when an invalid system package is detected during image build. |
| **Flyte SDK > Packages > flyte.errors > LogsNotYetAvailableError** | This error is raised when the logs are not yet available for a task. |
| **Flyte SDK > Packages > flyte.errors > ModuleLoadError** | This error is raised when the module cannot be loaded, either because it does not exist or because of a syntax error. |
| **Flyte SDK > Packages > flyte.errors > NonRecoverableError** | Raised when an error is encountered that is not recoverable. |
| **Flyte SDK > Packages > flyte.errors > NotInTaskContextError** | This error is raised when the user tries to access the task context outside of a task. |
| **Flyte SDK > Packages > flyte.errors > OOMError** | This error is raised when the underlying task execution fails because of an out-of-memory error. |
| **Flyte SDK > Packages > flyte.errors > OnlyAsyncIOSupportedError** | This error is raised when the user tries to use sync IO in an async task. |
| **Flyte SDK > Packages > flyte.errors > ParameterMaterializationError** | This error is raised when the user tries to use a Parameter in an App that has delayed materialization, but the materialization fails. |
| **Flyte SDK > Packages > flyte.errors > PrimaryContainerNotFoundError** | This error is raised when the primary container is not found. |
| **Flyte SDK > Packages > flyte.errors > RemoteTaskNotFoundError** | This error is raised when the user tries to access a task that does not exist. |
| **Flyte SDK > Packages > flyte.errors > RemoteTaskUsageError** | This error is raised when the user tries to access a task that does not exist. |
| **Flyte SDK > Packages > flyte.errors > RestrictedTypeError** | This error is raised when the user uses a restricted type; for example, a Tuple is currently not supported as a single value. |
| **Flyte SDK > Packages > flyte.errors > RetriesExhaustedError** | This error is raised when the underlying task execution fails after all retries have been exhausted. |
| **Flyte SDK > Packages > flyte.errors > RuntimeDataValidationError** | This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > RuntimeSystemError** | This error is raised when the underlying task execution fails because of a system error. |
| **Flyte SDK > Packages > flyte.errors > RuntimeUnknownError** | This error is raised when the underlying task execution fails because of an unknown error. |
| **Flyte SDK > Packages > flyte.errors > RuntimeUserError** | This error is raised when the underlying task execution fails because of an error in the user's code. |
| **Flyte SDK > Packages > flyte.errors > SlowDownError** | This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > TaskInterruptedError** | This error is raised when the underlying task execution is interrupted. |
| **Flyte SDK > Packages > flyte.errors > TaskTimeoutError** | This error is raised when the underlying task execution runs for longer than the specified timeout. |
| **Flyte SDK > Packages > flyte.errors > TraceDoesNotAllowNestedTasksError** | This error is raised when the user tries to use a task from within a trace. |
| **Flyte SDK > Packages > flyte.errors > UnionRpcError** | This error is raised when communication with the Union server fails. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.errors > Methods > silence_grpc_polling_error()** | Suppress specific gRPC polling errors in the event loop. |
## Methods
#### silence_grpc_polling_error()
```python
def silence_grpc_polling_error(
loop,
context,
)
```
Suppress specific gRPC polling errors in the event loop.
| Parameter | Type | Description |
|-|-|-|
| `loop` | | |
| `context` | | |
## Subpages
- **Flyte SDK > Packages > flyte.errors > ActionAbortedError**
- **Flyte SDK > Packages > flyte.errors > ActionNotFoundError**
- **Flyte SDK > Packages > flyte.errors > BaseRuntimeError**
- **Flyte SDK > Packages > flyte.errors > CodeBundleError**
- **Flyte SDK > Packages > flyte.errors > CustomError**
- **Flyte SDK > Packages > flyte.errors > DeploymentError**
- **Flyte SDK > Packages > flyte.errors > ImageBuildError**
- **Flyte SDK > Packages > flyte.errors > ImagePullBackOffError**
- **Flyte SDK > Packages > flyte.errors > InitializationError**
- **Flyte SDK > Packages > flyte.errors > InlineIOMaxBytesBreached**
- **Flyte SDK > Packages > flyte.errors > InvalidImageNameError**
- **Flyte SDK > Packages > flyte.errors > InvalidPackageError**
- **Flyte SDK > Packages > flyte.errors > LogsNotYetAvailableError**
- **Flyte SDK > Packages > flyte.errors > ModuleLoadError**
- **Flyte SDK > Packages > flyte.errors > NonRecoverableError**
- **Flyte SDK > Packages > flyte.errors > NotInTaskContextError**
- **Flyte SDK > Packages > flyte.errors > OnlyAsyncIOSupportedError**
- **Flyte SDK > Packages > flyte.errors > OOMError**
- **Flyte SDK > Packages > flyte.errors > ParameterMaterializationError**
- **Flyte SDK > Packages > flyte.errors > PrimaryContainerNotFoundError**
- **Flyte SDK > Packages > flyte.errors > RemoteTaskNotFoundError**
- **Flyte SDK > Packages > flyte.errors > RemoteTaskUsageError**
- **Flyte SDK > Packages > flyte.errors > RestrictedTypeError**
- **Flyte SDK > Packages > flyte.errors > RetriesExhaustedError**
- **Flyte SDK > Packages > flyte.errors > RuntimeDataValidationError**
- **Flyte SDK > Packages > flyte.errors > RuntimeSystemError**
- **Flyte SDK > Packages > flyte.errors > RuntimeUnknownError**
- **Flyte SDK > Packages > flyte.errors > RuntimeUserError**
- **Flyte SDK > Packages > flyte.errors > SlowDownError**
- **Flyte SDK > Packages > flyte.errors > TaskInterruptedError**
- **Flyte SDK > Packages > flyte.errors > TaskTimeoutError**
- **Flyte SDK > Packages > flyte.errors > TraceDoesNotAllowNestedTasksError**
- **Flyte SDK > Packages > flyte.errors > UnionRpcError**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/actionabortederror ===
# ActionAbortedError
**Package:** `flyte.errors`
This error is raised when an action was aborted externally. The parent action will raise this error.
## Parameters
```python
class ActionAbortedError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/actionnotfounderror ===
# ActionNotFoundError
**Package:** `flyte.errors`
This error is raised when the user tries to access an action that does not exist.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/baseruntimeerror ===
# BaseRuntimeError
**Package:** `flyte.errors`
Base class for all Union runtime errors. These errors are raised when the underlying task execution fails, either because of a user error, a system error, or an unknown error.
## Parameters
```python
class BaseRuntimeError(
code: str,
kind: typing.Literal['system', 'unknown', 'user'],
root_cause_message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `kind` | `typing.Literal['system', 'unknown', 'user']` | |
| `root_cause_message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/codebundleerror ===
# CodeBundleError
**Package:** `flyte.errors`
This error is raised when the code bundle cannot be created, for example when no files are found to bundle.
## Parameters
```python
class CodeBundleError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/customerror ===
# CustomError
**Package:** `flyte.errors`
This error is raised when the user raises a custom error.
## Parameters
```python
class CustomError(
code: str,
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.errors > CustomError > Methods > from_exception()** | Create a CustomError from an exception. |
### from_exception()
```python
def from_exception(
e: Exception,
)
```
Create a CustomError from an exception. The exception's class name is used as the error code and the exception
message is used as the error message.
| Parameter | Type | Description |
|-|-|-|
| `e` | `Exception` | |
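The behavior described above (class name becomes the code, message becomes the message) can be sketched with a stand-in class; `CustomSketchError` below is illustrative, not the SDK's `CustomError`.

```python
# Sketch of from_exception(): the exception's class name becomes the
# error code and its message becomes the error message.
# CustomSketchError is a stand-in for the SDK's CustomError.
class CustomSketchError(Exception):
    def __init__(self, code: str, message: str):
        super().__init__(message)
        self.code = code
        self.message = message

    @classmethod
    def from_exception(cls, e: Exception) -> "CustomSketchError":
        return cls(code=type(e).__name__, message=str(e))

err = CustomSketchError.from_exception(ValueError("bad input"))
```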
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/deploymenterror ===
# DeploymentError
**Package:** `flyte.errors`
This error is raised when the deployment of a task fails, or some preconditions for deployment are not met.
## Parameters
```python
class DeploymentError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/imagebuilderror ===
# ImageBuildError
**Package:** `flyte.errors`
This error is raised when the image build fails.
## Parameters
```python
class ImageBuildError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/imagepullbackofferror ===
# ImagePullBackOffError
**Package:** `flyte.errors`
This error is raised when the image cannot be pulled.
## Parameters
```python
class ImagePullBackOffError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/initializationerror ===
# InitializationError
**Package:** `flyte.errors`
This error is raised when the Union system is accessed without being initialized.
## Parameters
```python
class InitializationError(
code: str,
kind: typing.Literal['system', 'unknown', 'user'],
root_cause_message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `kind` | `typing.Literal['system', 'unknown', 'user']` | |
| `root_cause_message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/inlineiomaxbytesbreached ===
# InlineIOMaxBytesBreached
**Package:** `flyte.errors`
This error is raised when the inline IO max bytes limit is breached.
This can be adjusted per task by setting max_inline_io_bytes in the task definition.
## Parameters
```python
class InlineIOMaxBytesBreached(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/invalidimagenameerror ===
# InvalidImageNameError
**Package:** `flyte.errors`
This error is raised when the image name is invalid.
## Parameters
```python
class InvalidImageNameError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/invalidpackageerror ===
# InvalidPackageError
**Package:** `flyte.errors`
Raised when an invalid system package is detected during image build.
## Parameters
```python
class InvalidPackageError(
package_name: str,
original_error: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `package_name` | `str` | |
| `original_error` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/logsnotyetavailableerror ===
# LogsNotYetAvailableError
**Package:** `flyte.errors`
This error is raised when the logs are not yet available for a task.
## Parameters
```python
class LogsNotYetAvailableError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/moduleloaderror ===
# ModuleLoadError
**Package:** `flyte.errors`
This error is raised when the module cannot be loaded, either because it does not exist or because of a
syntax error.
## Parameters
```python
class ModuleLoadError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/nonrecoverableerror ===
# NonRecoverableError
**Package:** `flyte.errors`
Raised when an error is encountered that is not recoverable. Retries are irrelevant.
## Parameters
```python
class NonRecoverableError(
message: str,
code: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
| `code` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/notintaskcontexterror ===
# NotInTaskContextError
**Package:** `flyte.errors`
This error is raised when the user tries to access the task context outside of a task.
## Parameters
```python
class NotInTaskContextError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/onlyasynciosupportederror ===
# OnlyAsyncIOSupportedError
**Package:** `flyte.errors`
This error is raised when the user tries to use sync IO in an async task.
## Parameters
```python
class OnlyAsyncIOSupportedError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/oomerror ===
# OOMError
**Package:** `flyte.errors`
This error is raised when the underlying task execution fails because of an out-of-memory error.
## Parameters
```python
class OOMError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/parametermaterializationerror ===
# ParameterMaterializationError
**Package:** `flyte.errors`
This error is raised when the user tries to use a Parameter in an App that has delayed materialization, but the materialization fails.
## Parameters
```python
class ParameterMaterializationError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/primarycontainernotfounderror ===
# PrimaryContainerNotFoundError
**Package:** `flyte.errors`
This error is raised when the primary container is not found.
## Parameters
```python
class PrimaryContainerNotFoundError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/remotetasknotfounderror ===
# RemoteTaskNotFoundError
**Package:** `flyte.errors`
This error is raised when the user tries to access a task that does not exist.
## Parameters
```python
class RemoteTaskNotFoundError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/remotetaskusageerror ===
# RemoteTaskUsageError
**Package:** `flyte.errors`
This error is raised when the user tries to access a task that does not exist.
## Parameters
```python
class RemoteTaskUsageError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/restrictedtypeerror ===
# RestrictedTypeError
**Package:** `flyte.errors`
This error is raised when the user uses a restricted type; for example, a Tuple is currently not supported as a single value.
## Parameters
```python
class RestrictedTypeError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/retriesexhaustederror ===
# RetriesExhaustedError
**Package:** `flyte.errors`
This error is raised when the underlying task execution fails after all retries have been exhausted.
## Parameters
```python
class RetriesExhaustedError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/runtimedatavalidationerror ===
# RuntimeDataValidationError
**Package:** `flyte.errors`
This error is raised when the user tries to access a resource that does not exist or is invalid.
## Parameters
```python
class RuntimeDataValidationError(
var: str,
e: Exception | str,
task_name: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `var` | `str` | |
| `e` | `Exception \| str` | |
| `task_name` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/runtimesystemerror ===
# RuntimeSystemError
**Package:** `flyte.errors`
This error is raised when the underlying task execution fails because of a system error. This could be a bug in the
Union system or a bug in the user's code.
## Parameters
```python
class RuntimeSystemError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/runtimeunknownerror ===
# RuntimeUnknownError
**Package:** `flyte.errors`
This error is raised when the underlying task execution fails because of an unknown error.
## Parameters
```python
class RuntimeUnknownError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/runtimeusererror ===
# RuntimeUserError
**Package:** `flyte.errors`
This error is raised when the underlying task execution fails because of an error in the user's code.
## Parameters
```python
class RuntimeUserError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/slowdownerror ===
# SlowDownError
**Package:** `flyte.errors`
This error is raised when the user tries to access a resource that does not exist or is invalid.
## Parameters
```python
class SlowDownError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/taskinterruptederror ===
# TaskInterruptedError
**Package:** `flyte.errors`
This error is raised when the underlying task execution is interrupted.
## Parameters
```python
class TaskInterruptedError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/tasktimeouterror ===
# TaskTimeoutError
**Package:** `flyte.errors`
This error is raised when the underlying task execution runs for longer than the specified timeout.
## Parameters
```python
class TaskTimeoutError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/tracedoesnotallownestedtaskserror ===
# TraceDoesNotAllowNestedTasksError
**Package:** `flyte.errors`
This error is raised when the user tries to use a task from within a trace. Tasks can be nested under tasks, not traces.
## Parameters
```python
class TraceDoesNotAllowNestedTasksError(
message: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `message` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors/unionrpcerror ===
# UnionRpcError
**Package:** `flyte.errors`
This error is raised when communication with the Union server fails.
## Parameters
```python
class UnionRpcError(
code: str,
message: str,
worker: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `code` | `str` | |
| `message` | `str` | |
| `worker` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extend ===
# flyte.extend
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate** | A task template that wraps an asynchronous function. |
| **Flyte SDK > Packages > flyte.extend > ImageBuildEngine** | ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate** | Task template is a template for a task that can be executed. |
### Protocols
| Protocol | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > ImageBuilder** | |
| **Flyte SDK > Packages > flyte.extend > ImageChecker** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > Methods > download_code_bundle()** | Downloads the code bundle if it is not already downloaded. |
| **Flyte SDK > Packages > flyte.extend > Methods > get_proto_resources()** | Get main resources IDL representation from the resources object. |
| **Flyte SDK > Packages > flyte.extend > Methods > is_initialized()** | Check if the system has been initialized. |
| **Flyte SDK > Packages > flyte.extend > Methods > pod_spec_from_resources()** | |
### Variables
| Property | Type | Description |
|-|-|-|
| `PRIMARY_CONTAINER_DEFAULT_NAME` | `str` | |
| `TaskPluginRegistry` | `_Registry` | |
## Methods
#### download_code_bundle()
```python
def download_code_bundle(
code_bundle: flyte.models.CodeBundle,
) -> flyte.models.CodeBundle
```
Downloads the code bundle if it is not already downloaded.
| Parameter | Type | Description |
|-|-|-|
| `code_bundle` | `flyte.models.CodeBundle` | The code bundle to download. |
**Returns:** The code bundle with the downloaded path.
#### get_proto_resources()
```python
def get_proto_resources(
resources: flyte._resources.Resources | None,
) -> typing.Optional[flyteidl2.core.tasks_pb2.Resources]
```
Get main resources IDL representation from the resources object
| Parameter | Type | Description |
|-|-|-|
| `resources` | `flyte._resources.Resources \| None` | User facing Resources object containing potentially both requests and limits |
**Returns:** The given resources as requests and limits
#### is_initialized()
```python
def is_initialized()
```
Check if the system has been initialized.
**Returns:** True if initialized, False otherwise
#### pod_spec_from_resources()
```python
def pod_spec_from_resources(
primary_container_name: str,
requests: typing.Optional[flyte._resources.Resources],
limits: typing.Optional[flyte._resources.Resources],
k8s_gpu_resource_key: str,
) -> V1PodSpec
```
| Parameter | Type | Description |
|-|-|-|
| `primary_container_name` | `str` | |
| `requests` | `typing.Optional[flyte._resources.Resources]` | |
| `limits` | `typing.Optional[flyte._resources.Resources]` | |
| `k8s_gpu_resource_key` | `str` | |
## Subpages
- **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate**
- **Flyte SDK > Packages > flyte.extend > ImageBuildEngine**
- **Flyte SDK > Packages > flyte.extend > ImageBuilder**
- **Flyte SDK > Packages > flyte.extend > ImageChecker**
- **Flyte SDK > Packages > flyte.extend > TaskTemplate**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extend/asyncfunctiontasktemplate ===
# AsyncFunctionTaskTemplate
**Package:** `flyte.extend`
A task template that wraps an asynchronous function. This is automatically created when an asynchronous function is decorated with the task decorator.
## Parameters
```python
class AsyncFunctionTaskTemplate(
name: str,
interface: NativeInterface,
short_name: str,
task_type: str,
task_type_version: int,
image: Union[str, Image, Literal['auto']] | None,
resources: Optional[Resources],
cache: CacheRequest,
interruptible: bool,
retries: Union[int, RetryStrategy],
reusable: Union[ReusePolicy, None],
docs: Optional[Documentation],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
timeout: Optional[TimeoutType],
pod_template: Optional[Union[str, PodTemplate]],
report: bool,
queue: Optional[str],
debuggable: bool,
parent_env: Optional[weakref.ReferenceType[TaskEnvironment]],
parent_env_name: Optional[str],
max_inline_io_bytes: int,
triggers: Tuple[Trigger, ...],
links: Tuple[Link, ...],
_call_as_synchronous: bool,
func: F,
plugin_config: Optional[Any],
task_resolver: Optional[Any],
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `interface` | `NativeInterface` | |
| `short_name` | `str` | |
| `task_type` | `str` | |
| `task_type_version` | `int` | |
| `image` | `Union[str, Image, Literal['auto']] \| None` | |
| `resources` | `Optional[Resources]` | |
| `cache` | `CacheRequest` | |
| `interruptible` | `bool` | |
| `retries` | `Union[int, RetryStrategy]` | |
| `reusable` | `Union[ReusePolicy, None]` | |
| `docs` | `Optional[Documentation]` | |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `timeout` | `Optional[TimeoutType]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `report` | `bool` | |
| `queue` | `Optional[str]` | |
| `debuggable` | `bool` | |
| `parent_env` | `Optional[weakref.ReferenceType[TaskEnvironment]]` | |
| `parent_env_name` | `Optional[str]` | |
| `max_inline_io_bytes` | `int` | |
| `triggers` | `Tuple[Trigger, ...]` | |
| `links` | `Tuple[Link, ...]` | |
| `_call_as_synchronous` | `bool` | |
| `func` | `F` | |
| `plugin_config` | `Optional[Any]` | |
| `task_resolver` | `Optional[Any]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `json_schema` | `None` | JSON schema for the task inputs, following the Flyte standard. Delegates to NativeInterface.json_schema, which uses the type engine to produce a LiteralType per input and converts to JSON schema. |
| `native_interface` | `None` | |
| `source_file` | `None` | Returns the source file of the function, if available. This is useful for debugging and tracing. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > Methods > aio()** | The aio function allows executing "sync" tasks, in an async context. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > Methods > config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > Methods > container_args()** | Returns the container args for the task. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > Methods > custom_config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > Methods > data_loading_config()** | This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > Methods > execute()** | This is the execute method that will be called when the task is invoked. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > forward()** | Think of this as a local execute method for your task. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > override()** | Override various parameters of the task template. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > post()** | This is the postexecute function that will be called after the task is executed. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > pre()** | This is the preexecute function that will be called before the task is executed. |
| **Flyte SDK > Packages > flyte.extend > AsyncFunctionTaskTemplate > sql()** | Returns the SQL for the task. |
### aio()
```python
def aio(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
The aio function allows executing "sync" tasks in an async context. This helps when migrating sync tasks defined with the v1 SDK so they can be used within an asyncio parent task.
This function will also re-raise exceptions from the underlying task.
Example:
```python
import asyncio
from typing import List

@env.task
def my_legacy_task(x: int) -> int:
    return x

@env.task
async def my_new_parent_task(n: int) -> List[int]:
    collect = []
    for x in range(n):
        collect.append(my_legacy_task.aio(x))
    return await asyncio.gather(*collect)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### config()
```python
def config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### container_args()
```python
def container_args(
serialize_context: SerializationContext,
) -> List[str]
```
Returns the container arguments for the task, i.e. the command-line arguments used to launch the task container at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
### custom_config()
```python
def custom_config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### data_loading_config()
```python
def data_loading_config(
sctx: SerializationContext,
) -> DataLoadingConfig
```
This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. Flyte CoPilot eliminates the need for the SDK inside the container: any inputs required by the user's container are side-loaded into `input_path`, and any outputs the container writes to `output_path` are automatically uploaded.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### execute()
```python
def execute(
args: *args,
kwargs: **kwargs,
) -> R
```
This is the execute method that will be called when the task is invoked. It will call the actual function.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### forward()
```python
def forward(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
Think of this as a local execute method for your task. This function is invoked by the `__call__` method when not in a Flyte task execution context.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### override()
```python
def override(
short_name: Optional[str],
resources: Optional[Resources],
cache: Optional[CacheRequest],
retries: Union[int, RetryStrategy],
timeout: Optional[TimeoutType],
reusable: Union[ReusePolicy, Literal['off'], None],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
max_inline_io_bytes: int | None,
pod_template: Optional[Union[str, PodTemplate]],
queue: Optional[str],
interruptible: Optional[bool],
links: Tuple[Link, ...],
kwargs: **kwargs,
) -> TaskTemplate
```
Override various parameters of the task template. This allows for dynamic configuration of the task
when it is called, such as changing the image, resources, cache policy, etc.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional override for the short name of the task. |
| `resources` | `Optional[Resources]` | Optional override for the resources to use for the task. |
| `cache` | `Optional[CacheRequest]` | Optional override for the cache policy for the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional override for the number of retries for the task. |
| `timeout` | `Optional[TimeoutType]` | Optional override for the timeout for the task. |
| `reusable` | `Union[ReusePolicy, Literal['off'], None]` | Optional override for the reusability policy for the task. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional override for the environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional override for the secrets that will be injected into the task at runtime. |
| `max_inline_io_bytes` | `int \| None` | Optional override for the maximum allowed size (in bytes) for all inputs and outputs passed directly to the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional override for the pod template to use for the task. |
| `queue` | `Optional[str]` | Optional override for the queue to use for the task. |
| `interruptible` | `Optional[bool]` | Optional override for the interruptible policy for the task. |
| `links` | `Tuple[Link, ...]` | Optional override for the Links associated with the task. |
| `kwargs` | `**kwargs` | Additional keyword arguments for further overrides. Some fields like name, image, docs, and interface cannot be overridden. |
**Returns:** A new TaskTemplate instance with the overridden parameters.
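The copy-and-modify semantics of `override()` can be sketched with a plain dataclass. `MiniTemplate`, its fields, and the `"gpu-queue"` value are hypothetical stand-ins for illustration, not the real flyte `TaskTemplate` API:

```python
from dataclasses import dataclass, replace
from typing import Optional

# Hypothetical, simplified model of a task template, illustrating the
# override() semantics: a NEW template is returned and the original is
# left unchanged. The real flyte TaskTemplate has many more fields.
@dataclass(frozen=True)
class MiniTemplate:
    name: str               # not overridable, like the real template's name
    retries: int = 0
    queue: Optional[str] = None

    def override(self, **kwargs) -> "MiniTemplate":
        if "name" in kwargs:
            raise ValueError("name cannot be overridden")
        return replace(self, **kwargs)

base = MiniTemplate(name="my_task")
tuned = base.override(retries=3, queue="gpu-queue")
print(base.retries, tuned.retries, tuned.queue)  # → 0 3 gpu-queue
```

Because the template is copied rather than mutated, the same base task can be invoked with different overrides in different parts of a workflow.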
### post()
```python
def post(
return_vals: Any,
) -> Any
```
This is the postexecute function that will be called after the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `return_vals` | `Any` | |
### pre()
```python
def pre(
args,
kwargs,
) -> Dict[str, Any]
```
This is the preexecute function that will be called before the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### sql()
```python
def sql(
sctx: SerializationContext,
) -> Optional[str]
```
Returns the SQL statement for the task, if any. This is usually used by SQL-based plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extend/imagebuildengine ===
# ImageBuildEngine
**Package:** `flyte.extend`
ImageBuildEngine contains a list of builders that can be used to build an ImageSpec.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extend/imagebuilder ===
# ImageBuilder
**Package:** `flyte.extend`
```python
protocol ImageBuilder()
```
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > ImageBuilder > Methods > build_image()** | |
| **Flyte SDK > Packages > flyte.extend > ImageBuilder > Methods > get_checkers()** | Returns ImageCheckers that can be used to check if the image exists in the registry. |
### build_image()
```python
def build_image(
image: Image,
dry_run: bool,
wait: bool,
force: bool,
) -> 'ImageBuild'
```
| Parameter | Type | Description |
|-|-|-|
| `image` | `Image` | |
| `dry_run` | `bool` | |
| `wait` | `bool` | |
| `force` | `bool` | |
### get_checkers()
```python
def get_checkers()
```
Returns ImageCheckers that can be used to check if the image exists in the registry.
If None, then use the default checkers.
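Since `ImageBuilder` is a protocol, any class with matching methods satisfies it structurally, without inheriting from anything. The sketch below is a hypothetical illustration: `LocalStubBuilder` and its string return values are invented, and the simplified signature (`str` in, `str` out) is an assumption, whereas the real `build_image` takes an `Image` and returns an `ImageBuild`.

```python
from typing import Protocol, runtime_checkable

# Hypothetical illustration of satisfying an ImageBuilder-style protocol
# structurally. Names and signatures here are simplified assumptions and
# do not match the flyte SDK exactly.
@runtime_checkable
class BuilderLike(Protocol):
    def build_image(self, image: str, dry_run: bool, wait: bool, force: bool) -> str: ...

class LocalStubBuilder:
    def build_image(self, image: str, dry_run: bool, wait: bool, force: bool) -> str:
        # A real builder would shell out to docker/buildkit here.
        return f"plan:{image}" if dry_run else f"built:{image}"

builder = LocalStubBuilder()
print(isinstance(builder, BuilderLike))  # True: structural match, no subclassing
```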
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extend/imagechecker ===
# ImageChecker
**Package:** `flyte.extend`
```python
protocol ImageChecker()
```
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > ImageChecker > Methods > image_exists()** | |
### image_exists()
```python
def image_exists(
repository: str,
tag: str,
arch: Tuple[Architecture, ...],
) -> Optional[str]
```
| Parameter | Type | Description |
|-|-|-|
| `repository` | `str` | |
| `tag` | `str` | |
| `arch` | `Tuple[Architecture, ...]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extend/tasktemplate ===
# TaskTemplate
**Package:** `flyte.extend`
Task template is a template for a task that can be executed. It defines various parameters for the task, which
can be defined statically at the time of task definition or dynamically at the time of task invocation using
the override method.
Example usage:
```python
@task(name="my_task", image="my_image", resources=Resources(cpu="1", memory="1Gi"))
def my_task():
pass
```
## Parameters
```python
class TaskTemplate(
name: str,
interface: NativeInterface,
short_name: str,
task_type: str,
task_type_version: int,
image: Union[str, Image, Literal['auto']] | None,
resources: Optional[Resources],
cache: CacheRequest,
interruptible: bool,
retries: Union[int, RetryStrategy],
reusable: Union[ReusePolicy, None],
docs: Optional[Documentation],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
timeout: Optional[TimeoutType],
pod_template: Optional[Union[str, PodTemplate]],
report: bool,
queue: Optional[str],
debuggable: bool,
parent_env: Optional[weakref.ReferenceType[TaskEnvironment]],
parent_env_name: Optional[str],
max_inline_io_bytes: int,
triggers: Tuple[Trigger, ...],
links: Tuple[Link, ...],
_call_as_synchronous: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Optional The name of the task (defaults to the function name) |
| `interface` | `NativeInterface` | |
| `short_name` | `str` | |
| `task_type` | `str` | Router type for the task; this is used to determine how the task will be executed, and is usually set to match the execution plugin. |
| `task_type_version` | `int` | |
| `image` | `Union[str, Image, Literal['auto']] \| None` | Optional The image to use for the task, if set to "auto" will use the default image for the python version with flyte installed |
| `resources` | `Optional[Resources]` | Optional The resources to use for the task |
| `cache` | `CacheRequest` | Optional The cache policy for the task, defaults to auto, which will cache the results of the task. |
| `interruptible` | `bool` | Optional The interruptible policy for the task, defaults to False, which means the task will not be scheduled on interruptible nodes. If set to True, the task will be scheduled on interruptible nodes, and the code should handle interruptions and resumptions. |
| `retries` | `Union[int, RetryStrategy]` | Optional The number of retries for the task, defaults to 0, which means no retries. |
| `reusable` | `Union[ReusePolicy, None]` | Optional The reusability policy for the task, defaults to None, which means the task environment will not be reused across task invocations. |
| `docs` | `Optional[Documentation]` | Optional The documentation for the task, if not provided the function docstring will be used. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional The environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional The secrets that will be injected into the task at runtime. |
| `timeout` | `Optional[TimeoutType]` | Optional The timeout for the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional The pod template to use for the task. |
| `report` | `bool` | Optional Whether to report the task execution to the Flyte console, defaults to False. |
| `queue` | `Optional[str]` | Optional The queue to use for the task. If not provided, the default queue will be used. |
| `debuggable` | `bool` | Optional Whether the task supports debugging capabilities, defaults to False. |
| `parent_env` | `Optional[weakref.ReferenceType[TaskEnvironment]]` | |
| `parent_env_name` | `Optional[str]` | |
| `max_inline_io_bytes` | `int` | Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes. |
| `triggers` | `Tuple[Trigger, ...]` | |
| `links` | `Tuple[Link, ...]` | |
| `_call_as_synchronous` | `bool` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `native_interface` | `None` | |
| `source_file` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > aio()** | The aio function allows executing "sync" tasks, in an async context. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > container_args()** | Returns the container args for the task. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > custom_config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > data_loading_config()** | This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > execute()** | This is the pure python function that will be executed when the task is called. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > forward()** | Think of this as a local execute method for your task. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > override()** | Override various parameters of the task template. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > post()** | This is the postexecute function that will be called after the task is executed. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > pre()** | This is the preexecute function that will be called before the task is executed. |
| **Flyte SDK > Packages > flyte.extend > TaskTemplate > Methods > sql()** | Returns the SQL for the task. |
### aio()
```python
def aio(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
The aio function allows executing "sync" tasks in an async context. This helps when migrating sync tasks defined with the v1 SDK so they can be used within an asyncio parent task.
This function will also re-raise exceptions from the underlying task.
Example:
```python
import asyncio
from typing import List

@env.task
def my_legacy_task(x: int) -> int:
    return x

@env.task
async def my_new_parent_task(n: int) -> List[int]:
    collect = []
    for x in range(n):
        collect.append(my_legacy_task.aio(x))
    return await asyncio.gather(*collect)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### config()
```python
def config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### container_args()
```python
def container_args(
sctx: SerializationContext,
) -> List[str]
```
Returns the container arguments for the task, i.e. the command-line arguments used to launch the task container at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### custom_config()
```python
def custom_config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### data_loading_config()
```python
def data_loading_config(
sctx: SerializationContext,
) -> DataLoadingConfig
```
This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. Flyte CoPilot eliminates the need for the SDK inside the container: any inputs required by the user's container are side-loaded into `input_path`, and any outputs the container writes to `output_path` are automatically uploaded.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### execute()
```python
def execute(
args,
kwargs,
) -> Any
```
This is the pure python function that will be executed when the task is called.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### forward()
```python
def forward(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
Think of this as a local execute method for your task. This function is invoked by the `__call__` method when not in a Flyte task execution context.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### override()
```python
def override(
short_name: Optional[str],
resources: Optional[Resources],
cache: Optional[CacheRequest],
retries: Union[int, RetryStrategy],
timeout: Optional[TimeoutType],
reusable: Union[ReusePolicy, Literal['off'], None],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
max_inline_io_bytes: int | None,
pod_template: Optional[Union[str, PodTemplate]],
queue: Optional[str],
interruptible: Optional[bool],
links: Tuple[Link, ...],
kwargs: **kwargs,
) -> TaskTemplate
```
Override various parameters of the task template. This allows for dynamic configuration of the task
when it is called, such as changing the image, resources, cache policy, etc.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional override for the short name of the task. |
| `resources` | `Optional[Resources]` | Optional override for the resources to use for the task. |
| `cache` | `Optional[CacheRequest]` | Optional override for the cache policy for the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional override for the number of retries for the task. |
| `timeout` | `Optional[TimeoutType]` | Optional override for the timeout for the task. |
| `reusable` | `Union[ReusePolicy, Literal['off'], None]` | Optional override for the reusability policy for the task. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional override for the environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional override for the secrets that will be injected into the task at runtime. |
| `max_inline_io_bytes` | `int \| None` | Optional override for the maximum allowed size (in bytes) for all inputs and outputs passed directly to the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional override for the pod template to use for the task. |
| `queue` | `Optional[str]` | Optional override for the queue to use for the task. |
| `interruptible` | `Optional[bool]` | Optional override for the interruptible policy for the task. |
| `links` | `Tuple[Link, ...]` | Optional override for the Links associated with the task. |
| `kwargs` | `**kwargs` | Additional keyword arguments for further overrides. Some fields like name, image, docs, and interface cannot be overridden. |
**Returns:** A new TaskTemplate instance with the overridden parameters.
### post()
```python
def post(
return_vals: Any,
) -> Any
```
This is the postexecute function that will be called after the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `return_vals` | `Any` | |
### pre()
```python
def pre(
args,
kwargs,
) -> Dict[str, Any]
```
This is the preexecute function that will be called before the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### sql()
```python
def sql(
sctx: SerializationContext,
) -> Optional[str]
```
Returns the SQL statement for the task, if any. This is usually used by SQL-based plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras ===
# flyte.extras
Flyte extras package.
This package provides various utilities that make it possible to build highly customized workflows.
1. `ContainerTask`: Execute arbitrary pre-containerized applications without needing the `flyte-sdk` to be installed in the container. This extra uses the `flyte copilot` system to inject inputs into, and collect outputs from, the container run.
2. `DynamicBatcher` / `TokenBatcher`: Maximize resource utilization by batching work from many concurrent producers through a single async processing function. `DynamicBatcher` is the general-purpose base; `TokenBatcher` is a convenience subclass for token-budgeted LLM inference with reusable containers.
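The producer/batcher pattern described in item 2 can be modeled in a few lines of plain asyncio. `MiniBatcher` below is an illustrative toy, not the `flyte.extras.DynamicBatcher` API: many producers submit records and await per-record futures, while a single loop drains the queue into a batch and runs one async processing function over it.

```python
import asyncio
from typing import Any, Callable, List

# Toy model of the dynamic-batching pattern: producers submit records,
# a single loop assembles batches and resolves each producer's future.
# Names and signatures are illustrative, not the flyte.extras API.
class MiniBatcher:
    def __init__(self, process_fn: Callable[[List[Any]], Any], max_batch: int = 4):
        self.queue: asyncio.Queue = asyncio.Queue()
        self.process_fn = process_fn
        self.max_batch = max_batch

    async def submit(self, record: Any) -> Any:
        # Each producer gets a future resolved when its batch completes.
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((record, fut))
        return await fut

    async def run_once(self) -> None:
        # Wait for at least one record, then greedily drain up to max_batch.
        batch = [await self.queue.get()]
        while len(batch) < self.max_batch and not self.queue.empty():
            batch.append(self.queue.get_nowait())
        results = await self.process_fn([record for record, _ in batch])
        for (_, fut), res in zip(batch, results):
            fut.set_result(res)

async def main() -> List[int]:
    async def double_all(records: List[int]) -> List[int]:
        return [r * 2 for r in records]

    batcher = MiniBatcher(double_all)
    producers = [asyncio.create_task(batcher.submit(i)) for i in range(4)]
    await asyncio.sleep(0)  # let producers enqueue their records
    await batcher.run_once()  # one batch serves all four producers
    return await asyncio.gather(*producers)

print(asyncio.run(main()))  # → [0, 2, 4, 6]
```

The real `DynamicBatcher` adds cost-aware batch assembly, timeouts, and the monitoring statistics described under `BatchStats`.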
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > BatchStats** | Monitoring statistics exposed by `DynamicBatcher.stats`. |
| [`ContainerTask`](containertask/page.md) | This is an intermediate class that represents Flyte Tasks that run a container at execution time. |
| [`DynamicBatcher`](dynamicbatcher/page.md) | Batches records from many concurrent producers and runs them through a single async processing function. |
| [`Prompt`](prompt/page.md) | Simple prompt record with built-in token estimation. |
| [`TokenBatcher`](tokenbatcher/page.md) | Token-aware batcher for LLM inference workloads. |
### Protocols
| Protocol | Description |
|-|-|
| [`CostEstimator`](costestimator/page.md) | Protocol for records that can estimate their own processing cost. |
| [`TokenEstimator`](tokenestimator/page.md) | Protocol for records that can estimate their own token count. |
## Subpages
- **Flyte SDK > Packages > flyte.extras > BatchStats**
- **Flyte SDK > Packages > flyte.extras > ContainerTask**
- **Flyte SDK > Packages > flyte.extras > CostEstimator**
- **Flyte SDK > Packages > flyte.extras > DynamicBatcher**
- **Flyte SDK > Packages > flyte.extras > Prompt**
- **Flyte SDK > Packages > flyte.extras > TokenBatcher**
- **Flyte SDK > Packages > flyte.extras > TokenEstimator**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras/batchstats ===
# BatchStats
**Package:** `flyte.extras`
Monitoring statistics exposed by `DynamicBatcher.stats`.
Attributes:
- `total_submitted`: Total records submitted via `submit`.
- `total_completed`: Total records whose futures have been resolved.
- `total_batches`: Number of batches dispatched.
- `total_batch_cost`: Sum of estimated cost across all batches.
- `avg_batch_size`: Running average of records per batch.
- `avg_batch_cost`: Running average of cost per batch.
- `busy_time_s`: Cumulative seconds spent inside `process_fn`.
- `idle_time_s`: Cumulative seconds the processing loop waited for a batch to be assembled.
## Parameters
```python
class BatchStats(
total_submitted: int,
total_completed: int,
total_batches: int,
total_batch_cost: int,
avg_batch_size: float,
avg_batch_cost: float,
busy_time_s: float,
idle_time_s: float,
)
```
| Parameter | Type | Description |
|-|-|-|
| `total_submitted` | `int` | |
| `total_completed` | `int` | |
| `total_batches` | `int` | |
| `total_batch_cost` | `int` | |
| `avg_batch_size` | `float` | |
| `avg_batch_cost` | `float` | |
| `busy_time_s` | `float` | |
| `idle_time_s` | `float` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `utilization` | `None` | Fraction of wall-clock time spent processing (0.0-1.0). |
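A plausible reading of `utilization`, given the two timing attributes, is busy time over total observed time. The helper below is a guess at the formula, written from the attribute descriptions above, not copied from the `flyte.extras` implementation:

```python
# Presumed relationship between the timing attributes and utilization:
# the fraction of observed wall-clock time spent inside process_fn.
# This formula is inferred from the docs, not from the SDK source.
def utilization(busy_time_s: float, idle_time_s: float) -> float:
    total = busy_time_s + idle_time_s
    return busy_time_s / total if total > 0 else 0.0

print(utilization(7.5, 2.5))  # → 0.75
```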
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras/containertask ===
# ContainerTask
**Package:** `flyte.extras`
This is an intermediate class that represents Flyte tasks that run a container at execution time. This covers the vast majority of tasks: typical `@task`-decorated tasks, for instance, all run a container. An example of a task that does not run a container is the Athena SQL task.
## Parameters
```python
class ContainerTask(
name: str,
image: typing.Union[str, flyte._image.Image],
command: typing.List[str],
inputs: typing.Optional[typing.Dict[str, typing.Type]],
arguments: typing.Optional[typing.List[str]],
outputs: typing.Optional[typing.Dict[str, typing.Type]],
input_data_dir: str | pathlib.Path,
output_data_dir: str | pathlib.Path,
metadata_format: typing.Literal['JSON', 'YAML', 'PROTO'],
local_logs: bool,
block_network: bool,
kwargs,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name of the task |
| `image` | `typing.Union[str, flyte._image.Image]` | The container image to use for the task. This can be a string or an Image object. |
| `command` | `typing.List[str]` | The command to run in the container. This can be a list of strings or a single string. |
| `inputs` | `typing.Optional[typing.Dict[str, typing.Type]]` | The inputs to the task. This is a dictionary of input names to types. |
| `arguments` | `typing.Optional[typing.List[str]]` | The arguments to pass to the command. This is a list of strings. |
| `outputs` | `typing.Optional[typing.Dict[str, typing.Type]]` | The outputs of the task. This is a dictionary of output names to types. |
| `input_data_dir` | `str \| pathlib.Path` | The directory where the input data is stored. This is a string or a Path object. |
| `output_data_dir` | `str \| pathlib.Path` | The directory where the output data is stored. This is a string or a Path object. |
| `metadata_format` | `typing.Literal['JSON', 'YAML', 'PROTO']` | The format of the output file. This can be "JSON", "YAML", or "PROTO". |
| `local_logs` | `bool` | If True, logs will be printed to the console in the local execution. |
| `block_network` | `bool` | |
| `kwargs` | `**kwargs` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `native_interface` | `None` | |
| `source_file` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > aio()** | The aio function allows executing "sync" tasks, in an async context. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > config()** | Return the configuration for the container task, including network settings. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > container_args()** | Returns the container args for the task. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > custom_config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > data_loading_config()** | This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > execute()** | This is the pure python function that will be executed when the task is called. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > forward()** | Think of this as a local execute method for your task. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > override()** | Override various parameters of the task template. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > post()** | This is the postexecute function that will be called after the task is executed. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > pre()** | This is the preexecute function that will be called before the task is executed. |
| **Flyte SDK > Packages > flyte.extras > ContainerTask > Methods > sql()** | Returns the SQL for the task. |
### aio()
```python
def aio(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
The aio function allows executing "sync" tasks in an async context. This helps when migrating sync tasks defined with the v1 SDK so they can be used within an asyncio parent task.
This function will also re-raise exceptions from the underlying task.
Example:
```python
import asyncio
from typing import List

@env.task
def my_legacy_task(x: int) -> int:
    return x

@env.task
async def my_new_parent_task(n: int) -> List[int]:
    collect = []
    for x in range(n):
        collect.append(my_legacy_task.aio(x))
    return await asyncio.gather(*collect)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### config()
```python
def config(
sctx: flyte.models.SerializationContext,
) -> typing.Dict[str, str]
```
Return the configuration for the container task, including network settings.
This is for remote execution.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `flyte.models.SerializationContext` | |
### container_args()
```python
def container_args(
sctx: flyte.models.SerializationContext,
) -> typing.List[str]
```
Returns the container arguments for the task, i.e. the command-line arguments used to launch the task container at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `flyte.models.SerializationContext` | |
### custom_config()
```python
def custom_config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### data_loading_config()
```python
def data_loading_config(
sctx: flyte.models.SerializationContext,
) -> flyteidl2.core.tasks_pb2.DataLoadingConfig
```
This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. Flyte CoPilot eliminates the need for the SDK inside the container: any inputs required by the user's container are side-loaded into `input_path`, and any outputs the container writes to `output_path` are automatically uploaded.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `flyte.models.SerializationContext` | |
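The side-loading contract that CoPilot implements can be illustrated with a self-contained sketch: inputs are materialized on disk before the command runs, the command manipulates plain files with no SDK, and outputs are collected afterwards. The directory layout and JSON encoding below are assumptions for illustration, not the actual CoPilot wire format:

```python
import json
import subprocess
import sys
import tempfile
from pathlib import Path

# Toy model of the CoPilot side-loading contract. The real system side-loads
# inputs into input_path and uploads whatever the container writes under
# output_path; file names and JSON encoding here are illustrative only.
with tempfile.TemporaryDirectory() as tmp:
    input_dir = Path(tmp) / "inputs"
    output_dir = Path(tmp) / "outputs"
    input_dir.mkdir()
    output_dir.mkdir()

    # 1. The platform side-loads inputs before the command runs.
    (input_dir / "x.json").write_text(json.dumps(3))

    # 2. The "container command" reads/writes plain files; no flyte SDK needed.
    script = (
        "import json, pathlib, sys; "
        "inp, outp = map(pathlib.Path, sys.argv[1:3]); "
        "x = json.loads((inp / 'x.json').read_text()); "
        "(outp / 'y.json').write_text(json.dumps(x * 2))"
    )
    subprocess.run(
        [sys.executable, "-c", script, str(input_dir), str(output_dir)],
        check=True,
    )

    # 3. Outputs written under output_dir are collected after the run.
    result = json.loads((output_dir / "y.json").read_text())

print(result)  # → 6
```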
### execute()
```python
def execute(
kwargs,
) -> typing.Any
```
This is the pure python function that will be executed when the task is called.
| Parameter | Type | Description |
|-|-|-|
| `kwargs` | `**kwargs` | |
### forward()
```python
def forward(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
Think of this as a local execute method for your task. This function is invoked by the `__call__` method when not in a Flyte task execution context.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### override()
```python
def override(
short_name: Optional[str],
resources: Optional[Resources],
cache: Optional[CacheRequest],
retries: Union[int, RetryStrategy],
timeout: Optional[TimeoutType],
reusable: Union[ReusePolicy, Literal['off'], None],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
max_inline_io_bytes: int | None,
pod_template: Optional[Union[str, PodTemplate]],
queue: Optional[str],
interruptible: Optional[bool],
links: Tuple[Link, ...],
kwargs: **kwargs,
) -> TaskTemplate
```
Override various parameters of the task template. This allows for dynamic configuration of the task
when it is called, such as changing the image, resources, cache policy, etc.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional override for the short name of the task. |
| `resources` | `Optional[Resources]` | Optional override for the resources to use for the task. |
| `cache` | `Optional[CacheRequest]` | Optional override for the cache policy for the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional override for the number of retries for the task. |
| `timeout` | `Optional[TimeoutType]` | Optional override for the timeout for the task. |
| `reusable` | `Union[ReusePolicy, Literal['off'], None]` | Optional override for the reusability policy for the task. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional override for the environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional override for the secrets that will be injected into the task at runtime. |
| `max_inline_io_bytes` | `int \| None` | Optional override for the maximum allowed size (in bytes) for all inputs and outputs passed directly to the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional override for the pod template to use for the task. |
| `queue` | `Optional[str]` | Optional override for the queue to use for the task. |
| `interruptible` | `Optional[bool]` | Optional override for the interruptible policy for the task. |
| `links` | `Tuple[Link, ...]` | Optional override for the Links associated with the task. |
| `kwargs` | `**kwargs` | Additional keyword arguments for further overrides. Some fields like name, image, docs, and interface cannot be overridden. |
**Returns:** A new TaskTemplate instance with the overridden parameters.
### post()
```python
def post(
return_vals: Any,
) -> Any
```
The post-execute function, called after the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `return_vals` | `Any` | |
### pre()
```python
def pre(
args,
kwargs,
) -> Dict[str, Any]
```
The pre-execute function, called before the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### sql()
```python
def sql(
sctx: SerializationContext,
) -> Optional[str]
```
Returns the SQL statement for the task, if any, rendered using the given serialization context. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras/costestimator ===
# CostEstimator
**Package:** `flyte.extras`
Protocol for records that can estimate their own processing cost.
Implement this on your record type and the batcher will call it
automatically when no explicit `estimated_cost` is passed to
`DynamicBatcher.submit`.
Example:
```python
@dataclass
class ApiRequest:
    payload: str

    def estimate_cost(self) -> int:
        return len(self.payload)
```
```python
protocol CostEstimator()
```
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > CostEstimator > Methods > estimate_cost()** | |
### estimate_cost()
```python
def estimate_cost()
```
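The example above, made self-contained and paired with a minimal consumer that mimics the batcher's fallback behavior. The `cost_of` helper is an illustrative sketch, not SDK code:

```python
from dataclasses import dataclass

@dataclass
class ApiRequest:
    payload: str

    def estimate_cost(self) -> int:
        # Cost here is simply the payload length; any int-valued heuristic works
        return len(self.payload)

def cost_of(record, default_cost=1):
    # Use the CostEstimator protocol when available, else fall back to the default
    return record.estimate_cost() if hasattr(record, "estimate_cost") else default_cost
```

`cost_of(ApiRequest(payload="hello"))` yields `5`, which is what `DynamicBatcher` would use when no explicit `estimated_cost` is passed to `submit`.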
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras/dynamicbatcher ===
# DynamicBatcher
**Package:** `flyte.extras`
Batches records from many concurrent producers and runs them through
a single async processing function, maximizing resource utilization.
The batcher runs two internal loops:
1. **Aggregation loop** β drains the submission queue and assembles
cost-budgeted batches, respecting `target_batch_cost`,
`max_batch_size`, and `batch_timeout_s`.
2. **Processing loop** β pulls assembled batches and calls
`process_fn`, resolving each record's `asyncio.Future`.
Type parameters:
- `RecordT`: The input record type produced by your tasks.
- `ResultT`: The per-record output type returned by `process_fn`.
Example:
```python
async def process(batch: list[dict]) -> list[str]:
    ...

async with DynamicBatcher(process_fn=process) as batcher:
    futures = []
    for record in my_records:
        f = await batcher.submit(record)
        futures.append(f)
    results = await asyncio.gather(*futures)
```
## Parameters
```python
class DynamicBatcher(
    process_fn: ProcessFn[RecordT, ResultT],
    cost_estimator: CostEstimatorFn[RecordT] | None,
    target_batch_cost: int,
    max_batch_size: int,
    min_batch_size: int,
    batch_timeout_s: float,
    max_queue_size: int,
    prefetch_batches: int,
    default_cost: int,
)
```
| Parameter | Type | Description |
|-|-|-|
| `process_fn` | `ProcessFn[RecordT, ResultT]` | `async def f(batch: list[RecordT]) -> list[ResultT]`. Must return results in the **same order** as the input batch. |
| `cost_estimator` | `CostEstimatorFn[RecordT] \| None` | Optional `(RecordT) -> int` function, called to estimate the cost of each submitted record. Falls back to `record.estimate_cost()` if the record implements `CostEstimator`, then to `default_cost`. |
| `target_batch_cost` | `int` | Cost budget per batch. The aggregator fills batches up to this limit before dispatching. |
| `max_batch_size` | `int` | Hard cap on records per batch, regardless of cost budget. |
| `min_batch_size` | `int` | Minimum records before dispatching. Ignored when the timeout fires or shutdown is in progress. |
| `batch_timeout_s` | `float` | Maximum seconds to wait for a full batch. Lower values reduce idle time but may produce smaller batches. |
| `max_queue_size` | `int` | Bounded queue size. When full, `submit` awaits (backpressure). |
| `prefetch_batches` | `int` | Number of pre-assembled batches to buffer between the aggregation and processing loops. |
| `default_cost` | `int` | Fallback cost when no estimator is available. |
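The aggregation policy (fill a batch up to `target_batch_cost`, capped at `max_batch_size`) can be sketched without the SDK. This is an illustrative model of the loop's batching decision, not the actual implementation:

```python
def assemble_batch(records, costs, target_batch_cost, max_batch_size):
    """Greedily admit records until the cost budget or the size cap is hit."""
    batch, total_cost = [], 0
    for record, cost in zip(records, costs):
        # The first record is always admitted so oversized records still dispatch
        if batch and (total_cost + cost > target_batch_cost or len(batch) >= max_batch_size):
            break
        batch.append(record)
        total_cost += cost
    return batch, total_cost
```

For example, with per-record cost 3 and a budget of 7, a batch closes after two records; with a tiny budget and a `max_batch_size` of 2, the size cap closes it instead.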
## Properties
| Property | Type | Description |
|-|-|-|
| `is_running` | `bool` | Whether the aggregation and processing loops are active. |
| `stats` | `BatchStats` | Current `BatchStats` snapshot. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > DynamicBatcher > Methods > start()** | Start the aggregation and processing loops. |
| **Flyte SDK > Packages > flyte.extras > DynamicBatcher > Methods > stop()** | Graceful shutdown: process all enqueued work, then stop. |
| **Flyte SDK > Packages > flyte.extras > DynamicBatcher > Methods > submit()** | Submit a single record for batched processing. |
| **Flyte SDK > Packages > flyte.extras > DynamicBatcher > Methods > submit_batch()** | Convenience: submit multiple records and return their futures. |
### start()
```python
def start()
```
Start the aggregation and processing loops.
**Raises**
| Exception | Description |
|-|-|
| `RuntimeError` | If the batcher is already running. |
### stop()
```python
def stop()
```
Graceful shutdown: process all enqueued work, then stop.
Blocks until every pending future is resolved.
### submit()
```python
def submit(
record: RecordT,
estimated_cost: int | None,
) -> asyncio.Future[ResultT]
```
Submit a single record for batched processing.
Returns an `asyncio.Future` that resolves once the batch
containing this record has been processed.
Example:
```python
future = await batcher.submit(my_record, estimated_cost=128)
result = await future
```
| Parameter | Type | Description |
|-|-|-|
| `record` | `RecordT` | The input record. |
| `estimated_cost` | `int \| None` | Optional explicit cost. When omitted the batcher tries `cost_estimator`, then `record.estimate_cost()`, then `default_cost`. |
**Returns**
A future whose result is the corresponding entry from the list
returned by `process_fn`.
**Raises**
| Exception | Description |
|-|-|
| `RuntimeError` | If the batcher is not running. |
> [!NOTE]
> If the internal queue is full this coroutine awaits until space
> is available, providing natural backpressure to fast producers.
### submit_batch()
```python
def submit_batch(
records: Sequence[RecordT],
estimated_cost: Sequence[int] | None,
) -> list[asyncio.Future[ResultT]]
```
Convenience: submit multiple records and return their futures.
| Parameter | Type | Description |
|-|-|-|
| `records` | `Sequence[RecordT]` | Iterable of input records. |
| `estimated_cost` | `Sequence[int] \| None` | Optional per-record cost estimates. Length must match *records* when provided. |
**Returns:** List of futures, one per record.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras/prompt ===
# Prompt
**Package:** `flyte.extras`
Simple prompt record with built-in token estimation.
This is a convenience type for common LLM use cases. For richer
prompt types (e.g. with system messages, metadata), define your own
dataclass implementing `TokenEstimator`.
Attributes:
- `text`: The prompt text.
## Parameters
```python
class Prompt(
text: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `text` | `str` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > Prompt > Methods > estimate_tokens()** | Rough token estimate (~4 chars per token). |
### estimate_tokens()
```python
def estimate_tokens()
```
Rough token estimate (~4 chars per token).
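The heuristic can be sketched as below; the exact rounding is an assumption here, mirroring the `TokenEstimator` docstring example rather than `Prompt`'s actual implementation:

```python
def estimate_tokens(text: str) -> int:
    # ~4 characters per token; the +1 avoids a zero estimate for short strings
    return len(text) // 4 + 1
```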
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras/tokenbatcher ===
# TokenBatcher
**Package:** `flyte.extras`
Token-aware batcher for LLM inference workloads.
A thin convenience wrapper around `DynamicBatcher` that accepts
token-specific parameter names (`inference_fn`, `token_estimator`,
`target_batch_tokens`, etc.) and maps them to the base class.
Also checks the `TokenEstimator` protocol (`estimate_tokens()`)
in addition to `CostEstimator` (`estimate_cost()`).
Example:
```python
async def inference(batch: list[Prompt]) -> list[str]:
    ...

async with TokenBatcher(inference_fn=inference) as batcher:
    future = await batcher.submit(Prompt(text="Hello"))
    result = await future
```
## Parameters
```python
class TokenBatcher(
inference_fn: ProcessFn[RecordT, ResultT] | None,
process_fn: ProcessFn[RecordT, ResultT] | None,
token_estimator: CostEstimatorFn[RecordT] | None,
cost_estimator: CostEstimatorFn[RecordT] | None,
target_batch_tokens: int | None,
target_batch_cost: int,
default_token_estimate: int | None,
default_cost: int,
max_batch_size: int,
min_batch_size: int,
batch_timeout_s: float,
max_queue_size: int,
prefetch_batches: int,
)
```
| Parameter | Type | Description |
|-|-|-|
| `inference_fn` | `ProcessFn[RecordT, ResultT] \| None` | |
| `process_fn` | `ProcessFn[RecordT, ResultT] \| None` | |
| `token_estimator` | `CostEstimatorFn[RecordT] \| None` | |
| `cost_estimator` | `CostEstimatorFn[RecordT] \| None` | |
| `target_batch_tokens` | `int \| None` | |
| `target_batch_cost` | `int` | |
| `default_token_estimate` | `int \| None` | |
| `default_cost` | `int` | |
| `max_batch_size` | `int` | |
| `min_batch_size` | `int` | |
| `batch_timeout_s` | `float` | |
| `max_queue_size` | `int` | |
| `prefetch_batches` | `int` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `is_running` | `bool` | Whether the aggregation and processing loops are active. |
| `stats` | `BatchStats` | Current `BatchStats` snapshot. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > TokenBatcher > Methods > start()** | Start the aggregation and processing loops. |
| **Flyte SDK > Packages > flyte.extras > TokenBatcher > Methods > stop()** | Graceful shutdown: process all enqueued work, then stop. |
| **Flyte SDK > Packages > flyte.extras > TokenBatcher > Methods > submit()** | Submit a single record for batched inference. |
| **Flyte SDK > Packages > flyte.extras > TokenBatcher > Methods > submit_batch()** | Convenience: submit multiple records and return their futures. |
### start()
```python
def start()
```
Start the aggregation and processing loops.
**Raises**
| Exception | Description |
|-|-|
| `RuntimeError` | If the batcher is already running. |
### stop()
```python
def stop()
```
Graceful shutdown: process all enqueued work, then stop.
Blocks until every pending future is resolved.
### submit()
```python
def submit(
record: RecordT,
estimated_tokens: int | None,
estimated_cost: int | None,
) -> asyncio.Future[ResultT]
```
Submit a single record for batched inference.
Accepts either `estimated_tokens` or `estimated_cost`.
| Parameter | Type | Description |
|-|-|-|
| `record` | `RecordT` | The input record. |
| `estimated_tokens` | `int \| None` | Optional explicit token count. |
| `estimated_cost` | `int \| None` | Optional explicit cost (base class parameter). |
**Returns**
A future whose result is the corresponding entry from the list
returned by the inference function.
### submit_batch()
```python
def submit_batch(
records: Sequence[RecordT],
estimated_cost: Sequence[int] | None,
) -> list[asyncio.Future[ResultT]]
```
Convenience: submit multiple records and return their futures.
| Parameter | Type | Description |
|-|-|-|
| `records` | `Sequence[RecordT]` | Iterable of input records. |
| `estimated_cost` | `Sequence[int] \| None` | Optional per-record cost estimates. Length must match *records* when provided. |
**Returns:** List of futures, one per record.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras/tokenestimator ===
# TokenEstimator
**Package:** `flyte.extras`
Protocol for records that can estimate their own token count.
Implement this on your record type and the `TokenBatcher` will
call it automatically when no explicit `estimated_tokens` is passed
to `TokenBatcher.submit`.
Example:
```python
@dataclass
class Prompt:
    text: str

    def estimate_tokens(self) -> int:
        return len(self.text) // 4 + 1
```
```python
protocol TokenEstimator()
```
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > TokenEstimator > Methods > estimate_tokens()** | |
### estimate_tokens()
```python
def estimate_tokens()
```
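A self-contained version of the example above, using a `runtime_checkable` stand-in for the protocol. The name `SupportsTokenEstimate` is illustrative; the SDK defines its own `TokenEstimator`:

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsTokenEstimate(Protocol):
    def estimate_tokens(self) -> int: ...

@dataclass
class Prompt:
    text: str

    def estimate_tokens(self) -> int:
        # ~4 characters per token
        return len(self.text) // 4 + 1
```

Because the protocol is structural, any record with a matching `estimate_tokens()` method satisfies it without inheriting from anything.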
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.git ===
# flyte.git
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > GitStatus** | A class representing the status of a git repository. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > Methods > config_from_root()** | Get the config file from the git root directory. |
## Methods
#### config_from_root()
```python
def config_from_root(
path: pathlib.Path | str,
) -> flyte.config._config.Config | None
```
Get the config file from the git root directory.
By default, the config file is expected to be in `.flyte/config.yaml` in the git root directory.
| Parameter | Type | Description |
|-|-|-|
| `path` | `pathlib.Path \| str` | Path to the config file relative to the git root directory (default: `.flyte/config.yaml`). |
**Returns:** Config object if found, None otherwise
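The git-root discovery that `config_from_root` relies on can be sketched in plain Python. The SDK's actual mechanism may differ (for example, by shelling out to `git`):

```python
from pathlib import Path
from typing import Optional

def find_git_root(start: Path) -> Optional[Path]:
    """Walk up from `start` until a directory containing `.git` is found."""
    for candidate in [start, *start.parents]:
        if (candidate / ".git").exists():
            return candidate
    return None
```

A config path like `.flyte/config.yaml` would then be resolved relative to the returned root.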
## Subpages
- **Flyte SDK > Packages > flyte.git > GitStatus**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.git/gitstatus ===
# GitStatus
**Package:** `flyte.git`
A class representing the status of a git repository.
## Parameters
```python
class GitStatus(
is_valid: bool,
is_tree_clean: bool,
remote_url: str,
repo_dir: pathlib.Path,
commit_sha: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `is_valid` | `bool` | Whether git repository is valid |
| `is_tree_clean` | `bool` | Whether working tree is clean |
| `remote_url` | `str` | Remote URL in HTTPS format |
| `repo_dir` | `pathlib.Path` | Repository root directory |
| `commit_sha` | `str` | Current commit SHA |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > GitStatus > Methods > build_url()** | Build a git URL for the given path. |
| **Flyte SDK > Packages > flyte.git > GitStatus > Methods > from_current_repo()** | Discover git information from the current repository. |
### build_url()
```python
def build_url(
path: pathlib.Path | str,
line_number: int,
) -> str
```
Build a git URL for the given path.
| Parameter | Type | Description |
|-|-|-|
| `path` | `pathlib.Path \| str` | Path to a file |
| `line_number` | `int` | Line number of the code file |
**Returns:** A git URL for the given path and line number in the remote repository
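For intuition, here is what such a URL might look like for a GitHub-style remote. The exact format `build_url` produces depends on the git host and is not specified here, so treat this as a hypothetical sketch:

```python
def permalink(remote_url: str, commit_sha: str, rel_path: str, line_number: int) -> str:
    """Hypothetical GitHub-style permalink for a file at a specific commit and line."""
    base = remote_url[:-4] if remote_url.endswith(".git") else remote_url
    return f"{base}/blob/{commit_sha}/{rel_path}#L{line_number}"
```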
### from_current_repo()
```python
def from_current_repo()
```
Discover git information from the current repository.
If Git is not installed or .git does not exist, returns GitStatus with is_valid=False.
**Returns:** GitStatus instance with discovered git information
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io ===
# flyte.io
## IO data types
This package contains additional data types, beyond the primitive data types in Python, that abstract the data flow
of large datasets in Union.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io > DataFrame** | A Flyte meta DataFrame object, that wraps all other dataframe types (usually available as plugins, pandas. |
| **Flyte SDK > Packages > flyte.io > Dir** | A generic directory class representing a directory with files of a specified format. |
| **Flyte SDK > Packages > flyte.io > File** | A generic file class representing a file with a specified format. |
| **Flyte SDK > Packages > flyte.io > HashFunction** | A hash method that wraps a user-provided function to compute hashes. |
### Variables
| Property | Type | Description |
|-|-|-|
| `PARQUET` | `str` | |
## Subpages
- **Flyte SDK > Packages > flyte.io > DataFrame**
- **Flyte SDK > Packages > flyte.io > Dir**
- **Flyte SDK > Packages > flyte.io > File**
- **Flyte SDK > Packages > flyte.io > HashFunction**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io/dataframe ===
# DataFrame
**Package:** `flyte.io`
A Flyte meta DataFrame object that wraps all other dataframe types (usually available as plugins; pandas.DataFrame
and pyarrow.Table are supported natively, just install these libraries).
Known ecosystem plugins that supply other dataframe encoding plugins are:
1. `flyteplugins-polars` - pl.DataFrame
2. `flyteplugins-spark` - pyspark.DataFrame
You can add other implementations following `flyte.io.extend`.
The Flyte DataFrame object serves 2 main purposes:
1. Interoperability between various dataframe objects. A task can generate a pandas.DataFrame and another task
can accept a flyte.io.DataFrame, which can be converted to any dataframe type.
2. Non-materialized access to DataFrame objects. For example, you can accept any dataframe as a
flyte.io.DataFrame; this is just a reference and will not materialize until you call `.all()`, `.iter()`, etc.
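The second purpose, deferred materialization, follows a common lazy-reference pattern that can be sketched independently of the SDK (class and method names here are illustrative, not Flyte's internals):

```python
class LazyRef:
    """Hold a URI and a loader; materialize the data only on first access."""

    def __init__(self, uri, loader):
        self.uri = uri
        self._loader = loader
        self._data = None

    def all(self):
        # Data is fetched once, on demand, then cached
        if self._data is None:
            self._data = self._loader(self.uri)
        return self._data
```

Passing a `LazyRef` around is cheap because only the URI travels; the loader runs only when a consumer actually asks for the data.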
## Parameters
```python
class DataFrame(
uri: typing.Optional[str],
format: typing.Optional[str],
hash: typing.Optional[str],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `uri` | `typing.Optional[str]` | |
| `format` | `typing.Optional[str]` | |
| `hash` | `typing.Optional[str]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `lazy_uploader` | `None` | |
| `literal` | `None` | |
| `metadata` | `None` | |
| `val` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > all()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > all_sync()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > column_names()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > columns()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > deserialize_dataframe()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > from_df()** | Deprecated: Please use wrap_df, as that is the right name. |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > from_existing_remote()** | Create a DataFrame reference from an existing remote dataframe. |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > from_local()** | This method is useful to upload the dataframe eagerly and get the actual DataFrame. |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > from_local_sync()** | This method is useful to upload the dataframe eagerly and get the actual DataFrame. |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > iter()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > model_post_init()** | This function is meant to behave like a BaseModel method to initialise private attributes. |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > open()** | Load the handler if needed. |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > schema_match()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > serialize_dataframe()** | |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > set_literal()** | A public wrapper method to set the DataFrame Literal. |
| **Flyte SDK > Packages > flyte.io > DataFrame > Methods > wrap_df()** | Wrapper to create a DataFrame from a dataframe. |
### all()
```python
def all()
```
### all_sync()
```python
def all_sync()
```
### column_names()
```python
def column_names()
```
### columns()
```python
def columns()
```
### deserialize_dataframe()
```python
def deserialize_dataframe(
info,
) -> DataFrame
```
| Parameter | Type | Description |
|-|-|-|
| `info` | | |
### from_df()
```python
def from_df(
val: typing.Optional[typing.Any],
uri: typing.Optional[str],
) -> DataFrame
```
Deprecated: Please use wrap_df, as that is the right name.
Creates a new Flyte DataFrame from any registered DataFrame type (For example, pandas.DataFrame).
Other dataframe types are usually supported through plugins like `flyteplugins-polars`, `flyteplugins-spark`
etc.
| Parameter | Type | Description |
|-|-|-|
| `val` | `typing.Optional[typing.Any]` | |
| `uri` | `typing.Optional[str]` | |
### from_existing_remote()
```python
def from_existing_remote(
remote_path: str,
format: typing.Optional[str],
kwargs,
) -> 'DataFrame'
```
Create a DataFrame reference from an existing remote dataframe.
Example:
```python
df = DataFrame.from_existing_remote("s3://bucket/data.parquet", format="parquet")
```
| Parameter | Type | Description |
|-|-|-|
| `remote_path` | `str` | The remote path to the existing dataframe |
| `format` | `typing.Optional[str]` | Format of the stored dataframe |
| `kwargs` | `**kwargs` | |
### from_local()
```python
def from_local(
df: typing.Any,
columns: typing.OrderedDict[str, type[typing.Any]] | None,
remote_destination: str | None,
hash_method: HashMethod | str | None,
) -> DataFrame
```
This method eagerly uploads the dataframe and returns the actual DataFrame.
It is useful for uploading small local datasets onto Flyte, including dataframes from notebooks. It
uses signed URLs and is thus not the most efficient way of uploading.
In tasks (at runtime) it uses the task context and the underlying fast storage sub-system to upload the data;
at runtime it is recommended to use `DataFrame.wrap_df`, as it is simpler.
Example (With hash_method for cache key computation):
```python
import pandas as pd
from flyte.io import DataFrame, HashFunction
def hash_pandas_dataframe(df: pd.DataFrame) -> str:
return str(pd.util.hash_pandas_object(df).sum())
@env.task
async def foo() -> DataFrame:
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
hash_method = HashFunction.from_fn(hash_pandas_dataframe)
return await DataFrame.from_local(df, hash_method=hash_method)
```
| Parameter | Type | Description |
|-|-|-|
| `df` | `typing.Any` | The dataframe object to be uploaded and converted. |
| `columns` | `typing.OrderedDict[str, type[typing.Any]] \| None` | Optionally, any column information to be stored as part of the metadata |
| `remote_destination` | `str \| None` | Optional destination URI to upload to, if not specified, this is automatically determined based on the current context. For example, locally it will use flyte:// automatic data management system to upload data (this is slow and useful for smaller datasets). On remote it will use the storage configuration and the raw data directory setting in the task context. |
| `hash_method` | `HashMethod \| str \| None` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will compute the hash from the dataframe. If not specified, the cache key will be based on dataframe attributes. |
**Returns:** DataFrame object.
### from_local_sync()
```python
def from_local_sync(
df: typing.Any,
columns: typing.OrderedDict[str, type[typing.Any]] | None,
remote_destination: str | None,
hash_method: HashMethod | str | None,
) -> DataFrame
```
This method eagerly uploads the dataframe and returns the actual DataFrame.
It is useful for uploading small local datasets onto Flyte, including dataframes from notebooks. It
uses signed URLs and is thus not the most efficient way of uploading.
In tasks (at runtime) it uses the task context and the underlying fast storage sub-system to upload the data;
at runtime it is recommended to use `DataFrame.wrap_df`, as it is simpler.
Example (With hash_method for cache key computation):
```python
import pandas as pd
from flyte.io import DataFrame, HashFunction
def hash_pandas_dataframe(df: pd.DataFrame) -> str:
return str(pd.util.hash_pandas_object(df).sum())
@env.task
def foo() -> DataFrame:
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
hash_method = HashFunction.from_fn(hash_pandas_dataframe)
return DataFrame.from_local_sync(df, hash_method=hash_method)
```
| Parameter | Type | Description |
|-|-|-|
| `df` | `typing.Any` | The dataframe object to be uploaded and converted. |
| `columns` | `typing.OrderedDict[str, type[typing.Any]] \| None` | Optionally, any column information to be stored as part of the metadata |
| `remote_destination` | `str \| None` | Optional destination URI to upload to, if not specified, this is automatically determined based on the current context. For example, locally it will use flyte:// automatic data management system to upload data (this is slow and useful for smaller datasets). On remote it will use the storage configuration and the raw data directory setting in the task context. |
| `hash_method` | `HashMethod \| str \| None` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will compute the hash from the dataframe. If not specified, the cache key will be based on dataframe attributes. |
**Returns:** DataFrame object.
### iter()
```python
def iter()
```
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | The context. |
### open()
```python
def open(
dataframe_type: Type[DF],
)
```
Load the handler if needed. For a use case like:
```python
@task
def t1(df: DataFrame):
    import pandas as pd
    df.open(pd.DataFrame).all()
```
Since pandas is imported inside the task, the pandas handler won't be loaded during deserialization in the type engine.
| Parameter | Type | Description |
|-|-|-|
| `dataframe_type` | `Type[DF]` | |
### schema_match()
```python
def schema_match(
incoming: dict,
) -> bool
```
| Parameter | Type | Description |
|-|-|-|
| `incoming` | `dict` | |
### serialize_dataframe()
```python
def serialize_dataframe()
```
### set_literal()
```python
def set_literal(
expected: types_pb2.LiteralType,
)
```
A public wrapper method to set the DataFrame Literal.
This method provides external access to the internal _set_literal method.
| Parameter | Type | Description |
|-|-|-|
| `expected` | `types_pb2.LiteralType` | |
### wrap_df()
```python
def wrap_df(
val: typing.Optional[typing.Any],
uri: typing.Optional[str],
) -> DataFrame
```
Wrapper to create a DataFrame from a dataframe.
Other dataframe types are usually supported through plugins like `flyteplugins-polars`, `flyteplugins-spark`
etc.
| Parameter | Type | Description |
|-|-|-|
| `val` | `typing.Optional[typing.Any]` | |
| `uri` | `typing.Optional[str]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io/dir ===
# Dir
**Package:** `flyte.io`
A generic directory class representing a directory with files of a specified format.
Provides both async and sync interfaces for directory operations. All methods without _sync suffix are async.
The class should be instantiated using one of the class methods. The constructor should only be used to
instantiate references to existing remote directories.
The generic type T represents the format of the files in the directory.
Important methods:
- `from_existing_remote`: Create a Dir object referencing an existing remote directory.
- `from_local` / `from_local_sync`: Upload a local directory to remote storage.
**Asynchronous methods**:
- `walk`: Asynchronously iterate through files in the directory.
- `list_files`: Asynchronously get a list of all files (non-recursive).
- `download`: Asynchronously download the entire directory to a local path.
- `exists`: Asynchronously check if the directory exists.
- `get_file`: Asynchronously get a specific file from the directory by name.
**Synchronous methods** (suffixed with `_sync`):
- `walk_sync`: Synchronously iterate through files in the directory.
- `list_files_sync`: Synchronously get a list of all files (non-recursive).
- `download_sync`: Synchronously download the entire directory to a local path.
- `exists_sync`: Synchronously check if the directory exists.
- `get_file_sync`: Synchronously get a specific file from the directory by name.
Example: Walk through directory files recursively (Async).
```python
@env.task
async def process_all_files(d: Dir) -> int:
file_count = 0
async for file in d.walk(recursive=True):
async with file.open("rb") as f:
content = await f.read()
# Process content
file_count += 1
return file_count
```
Example: Walk through directory files recursively (Sync).
```python
@env.task
def process_all_files_sync(d: Dir) -> int:
file_count = 0
for file in d.walk_sync(recursive=True):
with file.open_sync("rb") as f:
content = f.read()
# Process content
file_count += 1
return file_count
```
Example: List files in directory (Async).
```python
@env.task
async def count_files(d: Dir) -> int:
files = await d.list_files()
return len(files)
```
Example: List files in directory (Sync).
```python
@env.task
def count_files_sync(d: Dir) -> int:
files = d.list_files_sync()
return len(files)
```
Example: Get a specific file from directory (Async).
```python
@env.task
async def read_config_file(d: Dir) -> str:
config_file = await d.get_file("config.json")
if config_file:
async with config_file.open("rb") as f:
return (await f.read()).decode("utf-8")
return "Config not found"
```
Example: Get a specific file from directory (Sync).
```python
@env.task
def read_config_file_sync(d: Dir) -> str:
config_file = d.get_file_sync("config.json")
if config_file:
with config_file.open_sync("rb") as f:
return f.read().decode("utf-8")
return "Config not found"
```
Example: Upload a local directory to remote storage (Async).
```python
@env.task
async def upload_directory() -> Dir:
# Create local directory with files
os.makedirs("/tmp/my_data", exist_ok=True)
with open("/tmp/my_data/file1.txt", "w") as f:
f.write("data1")
# Upload to remote storage
return await Dir.from_local("/tmp/my_data/")
```
Example: Upload a local directory to remote storage (Sync).
```python
@env.task
def upload_directory_sync() -> Dir:
# Create local directory with files
os.makedirs("/tmp/my_data", exist_ok=True)
with open("/tmp/my_data/file1.txt", "w") as f:
f.write("data1")
# Upload to remote storage
return Dir.from_local_sync("/tmp/my_data/")
```
Example: Download a directory to local storage (Async).
```python
@env.task
async def download_directory(d: Dir) -> str:
local_path = await d.download()
# Process files in local directory
return local_path
```
Example: Download a directory to local storage (Sync).
```python
@env.task
def download_directory_sync(d: Dir) -> str:
local_path = d.download_sync()
# Process files in local directory
return local_path
```
Example: Reference an existing remote directory.
```python
@env.task
async def process_existing_dir() -> int:
d = Dir.from_existing_remote("s3://my-bucket/data/")
files = await d.list_files()
return len(files)
```
Example: Check if directory exists (Async).
```python
@env.task
async def check_directory(d: Dir) -> bool:
return await d.exists()
```
Example: Check if directory exists (Sync).
```python
@env.task
def check_directory_sync(d: Dir) -> bool:
return d.exists_sync()
```
## Parameters
```python
class Dir(
path: str,
name: typing.Optional[str],
format: str,
hash: typing.Optional[str],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | The path to the directory (can be local or remote) |
| `name` | `typing.Optional[str]` | Optional name for the directory (defaults to basename of path) |
| `format` | `str` | |
| `hash` | `typing.Optional[str]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `lazy_uploader` | `None` | |
## Methods
| Method | Description |
|-|-|
| `download()` | Asynchronously download the entire directory to a local path. |
| `download_sync()` | Synchronously download the entire directory to a local path. |
| `exists()` | Asynchronously check if the directory exists. |
| `exists_sync()` | Synchronously check if the directory exists. |
| `from_existing_remote()` | Create a Dir reference from an existing remote directory. |
| `from_local()` | Asynchronously create a new Dir by uploading a local directory to remote storage. |
| `from_local_sync()` | Synchronously create a new Dir by uploading a local directory to remote storage. |
| `get_file()` | Asynchronously get a specific file from the directory by name. |
| `get_file_sync()` | Synchronously get a specific file from the directory by name. |
| `list_files()` | Asynchronously get a list of all files in the directory (non-recursive). |
| `list_files_sync()` | Synchronously get a list of all files in the directory (non-recursive). |
| `model_post_init()` | This function is meant to behave like a BaseModel method to initialise private attributes. |
| `new_remote()` | Create a new Dir reference for a remote directory that will be written to. |
| `pre_init()` | Internal: Pydantic validator to set default name from path. |
| `schema_match()` | Internal: Check if incoming schema matches Dir schema. |
| `walk()` | Asynchronously walk through the directory and yield File objects. |
| `walk_sync()` | Synchronously walk through the directory and yield File objects. |
### download()
```python
def download(
local_path: Optional[Union[str, Path]],
) -> str
```
Asynchronously download the entire directory to a local path.
Use this when you need to download all files in a directory to your local filesystem for processing.
Example (Async):
```python
@env.task
async def download_directory(d: Dir) -> str:
local_dir = await d.download()
# Process files in the local directory
return local_dir
```
Example (Async - Download to specific path):
```python
@env.task
async def download_to_path(d: Dir) -> str:
local_dir = await d.download("/tmp/my_data/")
return local_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the directory to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded directory
### download_sync()
```python
def download_sync(
local_path: Optional[Union[str, Path]],
) -> str
```
Synchronously download the entire directory to a local path.
Use this in non-async tasks when you need to download all files in a directory to your local filesystem.
Example (Sync):
```python
@env.task
def download_directory_sync(d: Dir) -> str:
local_dir = d.download_sync()
# Process files in the local directory
return local_dir
```
Example (Sync - Download to specific path):
```python
@env.task
def download_to_path_sync(d: Dir) -> str:
local_dir = d.download_sync("/tmp/my_data/")
return local_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the directory to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded directory
### exists()
```python
def exists()
```
Asynchronously check if the directory exists.
Example (Async):
```python
@env.task
async def check_directory(d: Dir) -> bool:
if await d.exists():
print("Directory exists!")
return True
return False
```
**Returns:** True if the directory exists, False otherwise
### exists_sync()
```python
def exists_sync()
```
Synchronously check if the directory exists.
Use this in non-async tasks or when you need synchronous directory existence checking.
Example (Sync):
```python
@env.task
def check_directory_sync(d: Dir) -> bool:
if d.exists_sync():
print("Directory exists!")
return True
return False
```
**Returns:** True if the directory exists, False otherwise
### from_existing_remote()
```python
def from_existing_remote(
remote_path: str,
dir_cache_key: Optional[str],
) -> Dir[T]
```
Create a Dir reference from an existing remote directory.
Use this when you want to reference a directory that already exists in remote storage without uploading it.
Example:
```python
@env.task
async def process_existing_directory() -> int:
d = Dir.from_existing_remote("s3://my-bucket/data/")
files = await d.list_files()
return len(files)
```
Example (With cache key):
```python
@env.task
async def process_with_cache_key() -> int:
d = Dir.from_existing_remote("s3://my-bucket/data/", dir_cache_key="abc123")
files = await d.list_files()
return len(files)
```
| Parameter | Type | Description |
|-|-|-|
| `remote_path` | `str` | The remote path to the existing directory |
| `dir_cache_key` | `Optional[str]` | Optional hash value to use for cache key computation. If not specified, the cache key will be computed based on the directory's attributes. |
**Returns:** A new Dir instance pointing to the existing remote directory
### from_local()
```python
def from_local(
local_path: Union[str, Path],
remote_destination: Optional[str],
dir_cache_key: Optional[str],
batch_size: Optional[int],
) -> Dir[T]
```
Asynchronously create a new Dir by uploading a local directory to remote storage.
Use this in async tasks when you have a local directory that needs to be uploaded to remote storage.
Example (Async):
```python
@env.task
async def upload_local_directory() -> Dir:
# Create a local directory with files
os.makedirs("/tmp/data_dir", exist_ok=True)
with open("/tmp/data_dir/file1.txt", "w") as f:
f.write("data1")
# Upload to remote storage
remote_dir = await Dir.from_local("/tmp/data_dir/")
return remote_dir
```
Example (Async - With specific destination):
```python
@env.task
async def upload_to_specific_path() -> Dir:
remote_dir = await Dir.from_local("/tmp/data_dir/", "s3://my-bucket/data/")
return remote_dir
```
Example (Async - With cache key):
```python
@env.task
async def upload_with_cache_key() -> Dir:
remote_dir = await Dir.from_local("/tmp/data_dir/", dir_cache_key="my_cache_key_123")
return remote_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local directory |
| `remote_destination` | `Optional[str]` | Optional remote path to store the directory. If None, a path will be automatically generated. |
| `dir_cache_key` | `Optional[str]` | Optional precomputed hash value to use for cache key computation when this Dir is used as an input to discoverable tasks. If not specified, the cache key will be based on directory attributes. |
| `batch_size` | `Optional[int]` | Optional concurrency limit for uploading files. If not specified, the default value is determined by the FLYTE_IO_BATCH_SIZE environment variable (default: 32). |
**Returns:** A new Dir instance pointing to the uploaded directory
### from_local_sync()
```python
def from_local_sync(
local_path: Union[str, Path],
remote_destination: Optional[str],
dir_cache_key: Optional[str],
) -> Dir[T]
```
Synchronously create a new Dir by uploading a local directory to remote storage.
Use this in non-async tasks when you have a local directory that needs to be uploaded to remote storage.
Example (Sync):
```python
@env.task
def upload_local_directory_sync() -> Dir:
# Create a local directory with files
os.makedirs("/tmp/data_dir", exist_ok=True)
with open("/tmp/data_dir/file1.txt", "w") as f:
f.write("data1")
# Upload to remote storage
remote_dir = Dir.from_local_sync("/tmp/data_dir/")
return remote_dir
```
Example (Sync - With specific destination):
```python
@env.task
def upload_to_specific_path_sync() -> Dir:
remote_dir = Dir.from_local_sync("/tmp/data_dir/", "s3://my-bucket/data/")
return remote_dir
```
Example (Sync - With cache key):
```python
@env.task
def upload_with_cache_key_sync() -> Dir:
remote_dir = Dir.from_local_sync("/tmp/data_dir/", dir_cache_key="my_cache_key_123")
return remote_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local directory |
| `remote_destination` | `Optional[str]` | Optional remote path to store the directory. If None, a path will be automatically generated. |
| `dir_cache_key` | `Optional[str]` | Optional precomputed hash value to use for cache key computation when this Dir is used as an input to discoverable tasks. If not specified, the cache key will be based on directory attributes. |
**Returns:** A new Dir instance pointing to the uploaded directory
### get_file()
```python
def get_file(
file_name: str,
) -> Optional[File[T]]
```
Asynchronously get a specific file from the directory by name.
Use this when you know the name of a specific file in the directory you want to access.
Example (Async):
```python
@env.task
async def read_specific_file(d: Dir) -> str:
file = await d.get_file("data.csv")
if file:
async with file.open("rb") as f:
content = await f.read()
return content.decode("utf-8")
return "File not found"
```
| Parameter | Type | Description |
|-|-|-|
| `file_name` | `str` | The name of the file to get |
**Returns:** A File instance if the file exists, None otherwise
### get_file_sync()
```python
def get_file_sync(
file_name: str,
) -> Optional[File[T]]
```
Synchronously get a specific file from the directory by name.
Use this in non-async tasks when you know the name of a specific file in the directory you want to access.
Example (Sync):
```python
@env.task
def read_specific_file_sync(d: Dir) -> str:
file = d.get_file_sync("data.csv")
if file:
with file.open_sync("rb") as f:
content = f.read()
return content.decode("utf-8")
return "File not found"
```
| Parameter | Type | Description |
|-|-|-|
| `file_name` | `str` | The name of the file to get |
**Returns:** A File instance if the file exists, None otherwise
### list_files()
```python
def list_files()
```
Asynchronously get a list of all files in the directory (non-recursive).
Use this when you need a list of all files in the top-level directory at once.
Example (Async):
```python
@env.task
async def count_files(d: Dir) -> int:
files = await d.list_files()
return len(files)
```
Example (Async - Process files):
```python
@env.task
async def process_all_files(d: Dir) -> list[str]:
files = await d.list_files()
contents = []
for file in files:
async with file.open("rb") as f:
content = await f.read()
contents.append(content.decode("utf-8"))
return contents
```
**Returns:** A list of File objects for files in the top-level directory
### list_files_sync()
```python
def list_files_sync()
```
Synchronously get a list of all files in the directory (non-recursive).
Use this in non-async tasks when you need a list of all files in the top-level directory at once.
Example (Sync):
```python
@env.task
def count_files_sync(d: Dir) -> int:
files = d.list_files_sync()
return len(files)
```
Example (Sync - Process files):
```python
@env.task
def process_all_files_sync(d: Dir) -> list[str]:
files = d.list_files_sync()
contents = []
for file in files:
with file.open_sync("rb") as f:
content = f.read()
contents.append(content.decode("utf-8"))
return contents
```
**Returns:** A list of File objects for files in the top-level directory
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | The context. |
### new_remote()
```python
def new_remote(
dir_name: Optional[str],
hash: Optional[str],
) -> Dir[T]
```
Create a new Dir reference for a remote directory that will be written to.
Use this when you want to create a new directory and write files into it
directly without creating a local directory first.
Example:
```python
@env.task
async def create() -> Dir:
    d = Dir.new_remote("output")
    # write files into d ...
    return d
```
| Parameter | Type | Description |
|-|-|-|
| `dir_name` | `Optional[str]` | Optional name for the remote directory. If not set, a generated name will be used. |
| `hash` | `Optional[str]` | Optional precomputed hash value to use for cache key computation when this Dir is used as an input to discoverable tasks. |
**Returns:** A new Dir instance with a generated remote path.
### pre_init()
```python
def pre_init(
data,
)
```
Internal: Pydantic validator to set default name from path. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `data` | | |
### schema_match()
```python
def schema_match(
incoming: dict,
)
```
Internal: Check if incoming schema matches Dir schema. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `incoming` | `dict` | |
### walk()
```python
def walk(
recursive: bool,
max_depth: Optional[int],
) -> AsyncIterator[File[T]]
```
Asynchronously walk through the directory and yield File objects.
Use this to iterate through all files in a directory. Each yielded File can be read directly without
downloading.
Example (Async - Recursive):
```python
@env.task
async def list_all_files(d: Dir) -> list[str]:
file_names = []
async for file in d.walk(recursive=True):
file_names.append(file.name)
return file_names
```
Example (Async - Non-recursive):
```python
@env.task
async def list_top_level_files(d: Dir) -> list[str]:
file_names = []
async for file in d.walk(recursive=False):
file_names.append(file.name)
return file_names
```
Example (Async - With max depth):
```python
@env.task
async def list_files_max_depth(d: Dir) -> list[str]:
file_names = []
async for file in d.walk(recursive=True, max_depth=2):
file_names.append(file.name)
return file_names
```
**Yields:** File objects for each file found in the directory
| Parameter | Type | Description |
|-|-|-|
| `recursive` | `bool` | If True, recursively walk subdirectories. If False, only list files in the top-level directory. |
| `max_depth` | `Optional[int]` | Maximum depth for recursive walking. If None, walk through all subdirectories. |
### walk_sync()
```python
def walk_sync(
recursive: bool,
file_pattern: str,
max_depth: Optional[int],
) -> Iterator[File[T]]
```
Synchronously walk through the directory and yield File objects.
Use this in non-async tasks to iterate through all files in a directory.
Example (Sync - Recursive):
```python
@env.task
def list_all_files_sync(d: Dir) -> list[str]:
file_names = []
for file in d.walk_sync(recursive=True):
file_names.append(file.name)
return file_names
```
Example (Sync - With file pattern):
```python
@env.task
def list_text_files(d: Dir) -> list[str]:
file_names = []
for file in d.walk_sync(recursive=True, file_pattern="*.txt"):
file_names.append(file.name)
return file_names
```
Example (Sync - Recursive with max depth):
```python
@env.task
def list_files_limited(d: Dir) -> list[str]:
file_names = []
for file in d.walk_sync(recursive=True, max_depth=2):
file_names.append(file.name)
return file_names
```
**Yields:** File objects for each file found in the directory
| Parameter | Type | Description |
|-|-|-|
| `recursive` | `bool` | If True, recursively walk subdirectories. If False, only list files in the top-level directory. |
| `file_pattern` | `str` | Glob pattern to filter files (e.g., "*.txt", "*.csv"). Default is "*" (all files). |
| `max_depth` | `Optional[int]` | Maximum depth for recursive walking. If None, walk through all subdirectories. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io/file ===
# File
**Package:** `flyte.io`
A generic file class representing a file with a specified format.
Provides both async and sync interfaces for file operations. All methods without _sync suffix are async.
The class should be instantiated using one of the class methods. The constructor should be used only to
instantiate references to existing remote objects.
The generic type T represents the format of the file.
Important methods:
- `from_existing_remote`: Create a File object from an existing remote file.
- `new_remote`: Create a new File reference for a remote file that will be written to.
**Asynchronous methods**:
- `open`: Asynchronously open the file and return a file-like object.
- `download`: Asynchronously download the file to a local path.
- `from_local`: Asynchronously create a File object from a local file, uploading it to remote storage.
- `exists`: Asynchronously check if the file exists.
**Synchronous methods** (suffixed with `_sync`):
- `open_sync`: Synchronously open the file and return a file-like object.
- `download_sync`: Synchronously download the file to a local path.
- `from_local_sync`: Synchronously create a File object from a local file, uploading it to remote storage.
- `exists_sync`: Synchronously check if the file exists.
Example: Read a file input in a Task (Async).
```python
@env.task
async def read_file(file: File) -> str:
async with file.open("rb") as f:
content = bytes(await f.read())
return content.decode("utf-8")
```
Example: Read a file input in a Task (Sync).
```python
@env.task
def read_file_sync(file: File) -> str:
with file.open_sync("rb") as f:
content = f.read()
return content.decode("utf-8")
```
Example: Write a file by streaming it directly to blob storage (Async).
```python
@env.task
async def write_file() -> File:
file = File.new_remote()
async with file.open("wb") as f:
await f.write(b"Hello, World!")
return file
```
Example: Upload a local file to remote storage (Async).
```python
@env.task
async def upload_file() -> File:
# Write to local file first
with open("/tmp/data.csv", "w") as f:
f.write("col1,col2\n1,2\n3,4\n")
# Upload to remote storage
return await File.from_local("/tmp/data.csv")
```
Example: Upload a local file to remote storage (Sync).
```python
@env.task
def upload_file_sync() -> File:
# Write to local file first
with open("/tmp/data.csv", "w") as f:
f.write("col1,col2\n1,2\n3,4\n")
# Upload to remote storage
return File.from_local_sync("/tmp/data.csv")
```
Example: Download a file to local storage (Async).
```python
@env.task
async def download_file(file: File) -> str:
local_path = await file.download()
# Process the local file
with open(local_path, "r") as f:
return f.read()
```
Example: Download a file to local storage (Sync).
```python
@env.task
def download_file_sync(file: File) -> str:
local_path = file.download_sync()
# Process the local file
with open(local_path, "r") as f:
return f.read()
```
Example: Reference an existing remote file.
```python
@env.task
async def process_existing_file() -> str:
file = File.from_existing_remote("s3://my-bucket/data.csv")
async with file.open("rb") as f:
content = await f.read()
return content.decode("utf-8")
```
Example: Check if a file exists (Async).
```python
@env.task
async def check_file(file: File) -> bool:
return await file.exists()
```
Example: Check if a file exists (Sync).
```python
@env.task
def check_file_sync(file: File) -> bool:
return file.exists_sync()
```
Example: Pass through a file without copying.
```python
@env.task
async def pass_through(file: File) -> File:
# No copy occurs - just passes the reference
return file
```
## Parameters
```python
class File(
path: str,
name: typing.Optional[str],
format: str,
hash: typing.Optional[str],
hash_method: typing.Optional[flyte.io._hashing_io.HashMethod],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | The path to the file (can be local or remote) |
| `name` | `typing.Optional[str]` | Optional name for the file (defaults to basename of path) |
| `format` | `str` | |
| `hash` | `typing.Optional[str]` | |
| `hash_method` | `typing.Optional[flyte.io._hashing_io.HashMethod]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `lazy_uploader` | `None` | |
## Methods
| Method | Description |
|-|-|
| `download()` | Asynchronously download the file to a local path. |
| `download_sync()` | Synchronously download the file to a local path. |
| `exists()` | Asynchronously check if the file exists. |
| `exists_sync()` | Synchronously check if the file exists. |
| `from_existing_remote()` | Create a File reference from an existing remote file. |
| `from_local()` | Asynchronously create a new File object from a local file by uploading it to remote storage. |
| `from_local_sync()` | Synchronously create a new File object from a local file by uploading it to remote storage. |
| `model_post_init()` | This function is meant to behave like a BaseModel method to initialise private attributes. |
| `named_remote()` | Create a File reference whose remote path is derived deterministically from *name*. |
| `new_remote()` | Create a new File reference for a remote file that will be written to. |
| `open()` | Asynchronously open the file and return a file-like object. |
| `open_sync()` | Synchronously open the file and return a file-like object. |
| `pre_init()` | Internal: Pydantic validator to set default name from path. |
| `schema_match()` | Internal: Check if incoming schema matches File schema. |
### download()
```python
def download(
local_path: Optional[Union[str, Path]],
) -> str
```
Asynchronously download the file to a local path.
Use this when you need to download a remote file to your local filesystem for processing.
Example (Async):
```python
@env.task
async def download_and_process(f: File) -> str:
local_path = await f.download()
# Now process the local file
with open(local_path, "r") as fh:
return fh.read()
```
Example (Download to specific path):
```python
@env.task
async def download_to_path(f: File) -> str:
local_path = await f.download("/tmp/myfile.csv")
return local_path
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the file to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded file
### download_sync()
```python
def download_sync(
local_path: Optional[Union[str, Path]],
) -> str
```
Synchronously download the file to a local path.
Use this in non-async tasks when you need to download a remote file to your local filesystem.
Example (Sync):
```python
@env.task
def download_and_process_sync(f: File) -> str:
local_path = f.download_sync()
# Now process the local file
with open(local_path, "r") as fh:
return fh.read()
```
Example (Download to specific path):
```python
@env.task
def download_to_path_sync(f: File) -> str:
local_path = f.download_sync("/tmp/myfile.csv")
return local_path
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the file to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded file
### exists()
```python
def exists()
```
Asynchronously check if the file exists.
Example (Async):
```python
@env.task
async def check_file(f: File) -> bool:
if await f.exists():
print("File exists!")
return True
return False
```
**Returns:** True if the file exists, False otherwise
### exists_sync()
```python
def exists_sync()
```
Synchronously check if the file exists.
Use this in non-async tasks or when you need synchronous file existence checking.
Example (Sync):
```python
@env.task
def check_file_sync(f: File) -> bool:
if f.exists_sync():
print("File exists!")
return True
return False
```
**Returns:** True if the file exists, False otherwise
### from_existing_remote()
```python
def from_existing_remote(
remote_path: str,
file_cache_key: Optional[str],
) -> File[T]
```
Create a File reference from an existing remote file.
Use this when you want to reference a file that already exists in remote storage without uploading it.
Example:
```python
@env.task
async def process_existing_file() -> str:
file = File.from_existing_remote("s3://my-bucket/data.csv")
async with file.open("rb") as f:
content = await f.read()
return content.decode("utf-8")
```
| Parameter | Type | Description |
|-|-|-|
| `remote_path` | `str` | The remote path to the existing file |
| `file_cache_key` | `Optional[str]` | Optional hash value to use for cache key computation. If not specified, the cache key will be computed based on the file's attributes (path, name, format). |
**Returns:** A new File instance pointing to the existing remote file
### from_local()
```python
def from_local(
local_path: Union[str, Path],
remote_destination: Optional[str],
hash_method: Optional[HashMethod | str],
) -> File[T]
```
Asynchronously create a new File object from a local file by uploading it to remote storage.
Use this in async tasks when you have a local file that needs to be uploaded to remote storage.
Example (Async):
```python
@env.task
async def upload_local_file() -> File:
# Create a local file
async with aiofiles.open("/tmp/data.csv", "w") as f:
        await f.write("col1,col2\n1,2\n3,4\n")
# Upload to remote storage
remote_file = await File.from_local("/tmp/data.csv")
return remote_file
```
Example (With specific destination):
```python
@env.task
async def upload_to_specific_path() -> File:
remote_file = await File.from_local("/tmp/data.csv", "s3://my-bucket/data.csv")
return remote_file
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local file |
| `remote_destination` | `Optional[str]` | Optional remote path to store the file. If None, a path will be automatically generated. |
| `hash_method` | `Optional[HashMethod \| str]` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will compute the hash during upload. If not specified, the cache key will be based on file attributes. |
**Returns:** A new File instance pointing to the uploaded remote file
### from_local_sync()
```python
def from_local_sync(
local_path: Union[str, Path],
remote_destination: Optional[str],
hash_method: Optional[HashMethod | str],
) -> File[T]
```
Synchronously create a new File object from a local file by uploading it to remote storage.
Use this in non-async tasks when you have a local file that needs to be uploaded to remote storage.
Example (Sync):
```python
@env.task
def upload_local_file_sync() -> File:
# Create a local file
with open("/tmp/data.csv", "w") as f:
        f.write("col1,col2\n1,2\n3,4\n")
# Upload to remote storage
remote_file = File.from_local_sync("/tmp/data.csv")
return remote_file
```
Example (With specific destination):
```python
@env.task
def upload_to_specific_path_sync() -> File:
remote_file = File.from_local_sync("/tmp/data.csv", "s3://my-bucket/data.csv")
return remote_file
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local file |
| `remote_destination` | `Optional[str]` | Optional remote path to store the file. If None, a path will be automatically generated. |
| `hash_method` | `Optional[HashMethod \| str]` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will compute the hash during upload. If not specified, the cache key will be based on file attributes. |
**Returns:** A new File instance pointing to the uploaded remote file
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | The context. |
### named_remote()
```python
def named_remote(
name: str,
) -> File[T]
```
Create a File reference whose remote path is derived deterministically from *name*.
Unlike `new_remote`, which generates a random path on every call, this method
produces the same path for the same *name* within a given task execution. This makes
it safe across retries: the first attempt uploads to the path and subsequent retries
resolve to the identical location without re-uploading.
The path is optionally namespaced by the node ID extracted from the backend
raw-data path, which follows the convention `{run_name}-{node_id}-{attempt_index}`.
If extraction fails, the method falls back to the run base directory alone.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Plain filename (e.g., "data.csv"). Must not contain path separators. |
**Returns:** A `File` instance whose path is stable across retries.
### new_remote()
```python
def new_remote(
file_name: Optional[str],
hash_method: Optional[HashMethod | str],
) -> File[T]
```
Create a new File reference for a remote file that will be written to.
Use this when you want to create a new file and write to it directly without creating a local file first.
Example (Async):
```python
@env.task
async def create_csv() -> File:
df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
file = File.new_remote()
async with file.open("wb") as f:
df.to_csv(f)
return file
```
| Parameter | Type | Description |
|-|-|-|
| `file_name` | `Optional[str]` | Optional string specifying a remote file name. If not set, a generated file name will be returned. |
| `hash_method` | `Optional[HashMethod \| str]` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will be used to compute the hash as data is written. |
**Returns:** A new File instance with a generated remote path
### open()
```python
def open(
mode: str,
block_size: Optional[int],
cache_type: str,
cache_options: Optional[dict],
compression: Optional[str],
kwargs,
) -> AsyncGenerator[Union[AsyncWritableFile, AsyncReadableFile, 'HashingWriter'], None]
```
Asynchronously open the file and return a file-like object.
Use this method in async tasks to read from or write to files directly.
Example (Async Read):
```python
@env.task
async def read_file(f: File) -> str:
async with f.open("rb") as fh:
content = bytes(await fh.read())
return content.decode("utf-8")
```
Example (Async Write):
```python
@env.task
async def write_file() -> File:
f = File.new_remote()
async with f.open("wb") as fh:
await fh.write(b"Hello, World!")
return f
```
Example (Streaming Read):
```python
@env.task
async def stream_read(f: File) -> str:
content_parts = []
async with f.open("rb", block_size=1024) as fh:
while True:
chunk = await fh.read()
if not chunk:
break
content_parts.append(chunk)
return b"".join(content_parts).decode("utf-8")
```
| Parameter | Type | Description |
|-|-|-|
| `mode` | `str` | The mode to open the file in (default: 'rb'). Common modes: 'rb' (read binary), 'wb' (write binary), 'rt' (read text), 'wt' (write text) |
| `block_size` | `Optional[int]` | Size of blocks for reading in bytes. Useful for streaming large files. |
| `cache_type` | `str` | Caching mechanism to use ('readahead', 'mmap', 'bytes', 'none') |
| `cache_options` | `Optional[dict]` | Dictionary of options for the cache |
| `compression` | `Optional[str]` | Compression format or None for auto-detection |
| `kwargs` | `**kwargs` | |
**Returns:** An async file-like object that can be used with async read/write operations
### open_sync()
```python
def open_sync(
mode: str,
block_size: Optional[int],
cache_type: str,
cache_options: Optional[dict],
compression: Optional[str],
kwargs,
) -> Generator[IO[Any], None, None]
```
Synchronously open the file and return a file-like object.
Use this method in non-async tasks to read from or write to files directly.
Example (Sync Read):
```python
@env.task
def read_file_sync(f: File) -> str:
with f.open_sync("rb") as fh:
content = fh.read()
return content.decode("utf-8")
```
Example (Sync Write):
```python
@env.task
def write_file_sync() -> File:
f = File.new_remote()
with f.open_sync("wb") as fh:
fh.write(b"Hello, World!")
return f
```
| Parameter | Type | Description |
|-|-|-|
| `mode` | `str` | The mode to open the file in (default: 'rb'). Common modes: 'rb' (read binary), 'wb' (write binary), 'rt' (read text), 'wt' (write text) |
| `block_size` | `Optional[int]` | Size of blocks for reading in bytes. Useful for streaming large files. |
| `cache_type` | `str` | Caching mechanism to use ('readahead', 'mmap', 'bytes', 'none') |
| `cache_options` | `Optional[dict]` | Dictionary of options for the cache |
| `compression` | `Optional[str]` | Compression format or None for auto-detection |
| `kwargs` | `**kwargs` | |
**Returns:** A file-like object that can be used with standard read/write operations
### pre_init()
```python
def pre_init(
data,
)
```
Internal: Pydantic validator to set default name from path. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `data` | | |
### schema_match()
```python
def schema_match(
incoming: dict,
)
```
Internal: Check if incoming schema matches File schema. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `incoming` | `dict` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io/hashfunction ===
# HashFunction
**Package:** `flyte.io`
A hash method that wraps a user-provided function to compute hashes.
This class allows you to define custom hashing logic by providing a callable
that takes data and returns a hash string. It implements the HashMethod protocol,
making it compatible with Flyte's hashing infrastructure.
Example:
```python
>>> def my_hash(data: bytes) -> str:
...     return hashlib.md5(data).hexdigest()
>>> hash_fn = HashFunction.from_fn(my_hash)
>>> hash_fn.update(b"hello")
>>> hash_fn.result()
'5d41402abc4b2a76b9719d911017c592'
```
Attributes:
- `_fn`: The callable that computes the hash from input data.
- `_value`: The most recently computed hash value.
## Parameters
```python
class HashFunction(
fn: Callable[[Any], str],
)
```
Initialize a HashFunction with a custom hash callable.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[[Any], str]` | A callable that takes data of any type and returns a hash string. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io > HashFunction > Methods > from_fn()** | Create a HashFunction from a callable. |
| **Flyte SDK > Packages > flyte.io > HashFunction > Methods > reset()** | |
| **Flyte SDK > Packages > flyte.io > HashFunction > Methods > result()** | Return the most recently computed hash value. |
| **Flyte SDK > Packages > flyte.io > HashFunction > Methods > update()** | Update the hash value by applying the hash function to the given data. |
### from_fn()
```python
def from_fn(
fn: Callable[[Any], str],
) -> HashFunction
```
Create a HashFunction from a callable.
This is a convenience factory method for creating HashFunction instances.
Example:
```python
>>> hash_fn = HashFunction.from_fn(lambda x: hashlib.sha256(x).hexdigest())
```
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[[Any], str]` | A callable that takes data of any type and returns a hash string. |
**Returns**
A new HashFunction instance wrapping the provided callable.
### reset()
```python
def reset()
```
### result()
```python
def result()
```
Return the most recently computed hash value.
**Returns:** The hash string from the last call to update().
### update()
```python
def update(
data: Any,
)
```
Update the hash value by applying the hash function to the given data.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | The data to hash. The type depends on the hash function provided. |
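The update/result semantics above can be illustrated with a standalone sketch. This is a hypothetical reimplementation for illustration only (the real class is `flyte.io.HashFunction`); note that, per the documented behavior, each `update()` call replaces the stored value with the hash of that call's data rather than accumulating a digest:

```python
import hashlib

# Illustrative stand-in for the HashMethod-style protocol that
# HashFunction implements (not the SDK class itself).
class SketchHashFunction:
    def __init__(self, fn):
        self._fn = fn          # callable computing a hash string from data
        self._value = None     # most recently computed hash value

    def update(self, data):
        # Applies fn to the given data; result() returns the value from
        # the last update() call.
        self._value = self._fn(data)

    def result(self):
        return self._value

    def reset(self):
        self._value = None

h = SketchHashFunction(lambda b: hashlib.md5(b).hexdigest())
h.update(b"hello")
assert h.result() == "5d41402abc4b2a76b9719d911017c592"
```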
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io.extend ===
# flyte.io.extend
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io.extend > DataFrameDecoder** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameEncoder** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine** | Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. |
## Subpages
- **Flyte SDK > Packages > flyte.io.extend > DataFrameDecoder**
- **Flyte SDK > Packages > flyte.io.extend > DataFrameEncoder**
- **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io.extend/dataframedecoder ===
# DataFrameDecoder
**Package:** `flyte.io.extend`
## Parameters
```python
class DataFrameDecoder(
python_type: Type[DF],
protocol: Optional[str],
supported_format: Optional[str],
additional_protocols: Optional[List[str]],
)
```
Extend this abstract class, implement the decode function, and register your concrete class with the
DataFrameTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the decoder interface, meaning it is used when there is a Flyte Literal value,
and we have to get a Python value out of it. For the other way, see the DataFrameEncoder
| Parameter | Type | Description |
|-|-|-|
| `python_type` | `Type[DF]` | The dataframe class in question that you want to register this decoder with |
| `protocol` | `Optional[str]` | A prefix representing the storage driver (e.g. 's3', 'gs', 'bq', etc.). You can use either "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor. If None, this decoder will be registered with all protocols that flytekit's data persistence layer is capable of handling. |
| `supported_format` | `Optional[str]` | Arbitrary string representing the format. If not supplied then an empty string will be used. An empty string implies that the decoder works with any format. If the format being asked for does not exist, the transformer engine will look for the "" decoder instead and write a warning. |
| `additional_protocols` | `Optional[List[str]]` | |
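The empty-string fallback described for `supported_format` can be sketched in plain Python. This is an illustrative model of the lookup order only; the function name and registry shape are assumptions, not the transformer engine's internals:

```python
def get_handler(registry: dict, protocol: str, fmt: str):
    # Look up a handler for (protocol, format); an empty-string format
    # means "works with any format" and serves as the fallback.
    by_format = registry.get(protocol, {})
    if fmt in by_format:
        return by_format[fmt]
    # Fall back to the "" (any-format) handler, if one exists.
    return by_format.get("")

registry = {"s3": {"parquet": "parquet-decoder", "": "generic-decoder"}}
assert get_handler(registry, "s3", "parquet") == "parquet-decoder"
assert get_handler(registry, "s3", "csv") == "generic-decoder"
```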
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io.extend > DataFrameDecoder > Methods > decode()** | This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal. |
### decode()
```python
def decode(
flyte_value: literals_pb2.StructuredDataset,
current_task_metadata: literals_pb2.StructuredDatasetMetadata,
) -> Union[DF, typing.AsyncIterator[DF]]
```
This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal
value into a Python instance.
| Parameter | Type | Description |
|-|-|-|
| `flyte_value` | `literals_pb2.StructuredDataset` | This will be a Flyte IDL DataFrame Literal - do not confuse this with the DataFrame class defined also in this module. |
| `current_task_metadata` | `literals_pb2.StructuredDatasetMetadata` | Metadata object containing the type (and columns if any) for the currently executing task. This type may have more or less information than the type information bundled inside the incoming flyte_value. |
**Returns:** This function can either return an instance of the dataframe that this decoder handles, or an async iterator of those dataframes.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io.extend/dataframeencoder ===
# DataFrameEncoder
**Package:** `flyte.io.extend`
## Parameters
```python
class DataFrameEncoder(
python_type: Type[T],
protocol: Optional[str],
supported_format: Optional[str],
)
```
Extend this abstract class, implement the encode function, and register your concrete class with the
DataFrameTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the encoding interface, meaning it is used when there is a Python value that the
flytekit type engine is trying to convert into a Flyte Literal. For the other way, see
the DataFrameDecoder.
| Parameter | Type | Description |
|-|-|-|
| `python_type` | `Type[T]` | The dataframe class in question that you want to register this encoder with |
| `protocol` | `Optional[str]` | A prefix representing the storage driver (e.g. 's3', 'gs', 'bq', etc.). You can use either "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor. If None, this encoder will be registered with all protocols that flytekit's data persistence layer is capable of handling. |
| `supported_format` | `Optional[str]` | Arbitrary string representing the format. If not supplied then an empty string will be used. An empty string implies that the encoder works with any format. If the format being asked for does not exist, the transformer engine will look for the "" encoder instead and write a warning. |
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io.extend > DataFrameEncoder > Methods > encode()** | Even if the user code returns a plain dataframe instance, the dataset transformer engine will wrap the incoming dataframe. |
### encode()
```python
def encode(
dataframe: DataFrame,
structured_dataset_type: types_pb2.StructuredDatasetType,
) -> literals_pb2.StructuredDataset
```
Even if the user code returns a plain dataframe instance, the dataset transformer engine will wrap the
incoming dataframe with defaults set for that dataframe type. This simplifies this function's interface:
much of the data that could otherwise be specified by the user is instead carried by the DataFrame wrapper
class used as input to this function (the user-facing Python class).
This function needs to return the IDL DataFrame.
| Parameter | Type | Description |
|-|-|-|
| `dataframe` | `DataFrame` | This is a DataFrame wrapper object. See more info above. |
| `structured_dataset_type` | `types_pb2.StructuredDatasetType` | This the DataFrameType, as found in the LiteralType of the interface of the task that invoked this encoding call. It is passed along to encoders so that authors of encoders can include it in the returned literals.DataFrame. See the IDL for more information on why this literal in particular carries the type information along with it. If the encoder doesn't supply it, it will also be filled in after the encoder runs by the transformer engine. |
**Returns:** This function should return a DataFrame literal object. Do not confuse this with the DataFrame class also defined in this module.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io.extend/dataframetransformerengine ===
# DataFrameTransformerEngine
**Package:** `flyte.io.extend`
Think of this transformer as a higher-level meta transformer that is used for all the dataframe types.
If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of
registering with the main type engine, you should register with this transformer instead.
## Parameters
```python
def DataFrameTransformerEngine()
```
## Properties
| Property | Type | Description |
|-|-|-|
| `name` | `None` | |
| `python_type` | `None` | This returns the python type |
| `type_assertions_enabled` | `None` | Indicates if the transformer wants type assertions to be enabled at the core type engine layer |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > assert_type()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > encode()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > from_binary_idl()** | This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and attribute access. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > get_decoder()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > get_encoder()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > get_literal_type()** | Provide a concrete implementation so that writers of custom dataframe handlers don't have to. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > get_structured_dataset_type()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > guess_python_type()** | Converts the Flyte LiteralType to a python object type. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > isinstance_generic()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > iter_as()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > open_as()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > register()** | Call this with any Encoder or Decoder to register it with the flytekit type system. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > register_for_protocol()** | See the main register function instead. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > register_renderer()** | |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > schema_match()** | Check if a JSON schema fragment matches this transformer's python_type. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > to_html()** | Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > to_literal()** | Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type. |
| **Flyte SDK > Packages > flyte.io.extend > DataFrameTransformerEngine > Methods > to_python_value()** | The only tricky thing with converting a Literal (say the output of an earlier task) to a Python value is the column subsetting behavior. |
### assert_type()
```python
def assert_type(
t: Type[DataFrame],
v: typing.Any,
)
```
| Parameter | Type | Description |
|-|-|-|
| `t` | `Type[DataFrame]` | |
| `v` | `typing.Any` | |
### encode()
```python
def encode(
df: DataFrame,
df_type: Type,
protocol: str,
format: str,
structured_literal_type: types_pb2.StructuredDatasetType,
) -> literals_pb2.Literal
```
| Parameter | Type | Description |
|-|-|-|
| `df` | `DataFrame` | |
| `df_type` | `Type` | |
| `protocol` | `str` | |
| `format` | `str` | |
| `structured_literal_type` | `types_pb2.StructuredDatasetType` | |
### from_binary_idl()
```python
def from_binary_idl(
binary_idl_object: Binary,
expected_python_type: Type[T],
) -> Optional[T]
```
This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and
attribute access.
For untyped dicts, dataclasses, and Pydantic BaseModels:
```
python val -> msgpack bytes -> binary literal scalar -> msgpack bytes -> python val
          (to_literal)                               (from_binary_idl)
```
For attribute access:
```
python val -> msgpack bytes -> binary literal scalar -> resolved golang value
          (to_literal)           (propeller attribute access)
-> binary literal scalar -> msgpack bytes -> python val
                         (from_binary_idl)
```
| Parameter | Type | Description |
|-|-|-|
| `binary_idl_object` | `Binary` | |
| `expected_python_type` | `Type[T]` | |
### get_decoder()
```python
def get_decoder(
df_type: Type,
protocol: str,
format: str,
) -> DataFrameDecoder
```
| Parameter | Type | Description |
|-|-|-|
| `df_type` | `Type` | |
| `protocol` | `str` | |
| `format` | `str` | |
### get_encoder()
```python
def get_encoder(
df_type: Type,
protocol: str,
format: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `df_type` | `Type` | |
| `protocol` | `str` | |
| `format` | `str` | |
### get_literal_type()
```python
def get_literal_type(
t: typing.Union[Type[DataFrame], typing.Any],
) -> types_pb2.LiteralType
```
Provide a concrete implementation so that writers of custom dataframe handlers don't have to, since there's
nothing special about the literal type: any dataframe type will always be associated with the structured dataset type.
The other aspects of it (columns, external schema type, etc.) can be read from associated metadata.
| Parameter | Type | Description |
|-|-|-|
| `t` | `typing.Union[Type[DataFrame], typing.Any]` | The python dataframe type, which is mostly ignored. |
### get_structured_dataset_type()
```python
def get_structured_dataset_type(
storage_format: str | None,
pa_schema: Optional['pa.lib.Schema'],
column_map: typing.OrderedDict[str, type[typing.Any]] | None,
) -> types_pb2.StructuredDatasetType
```
| Parameter | Type | Description |
|-|-|-|
| `storage_format` | `str \| None` | |
| `pa_schema` | `Optional['pa.lib.Schema']` | |
| `column_map` | `typing.OrderedDict[str, type[typing.Any]] \| None` | |
### guess_python_type()
```python
def guess_python_type(
literal_type: types_pb2.LiteralType,
) -> Type[DataFrame]
```
Converts the Flyte LiteralType to a python object type.
| Parameter | Type | Description |
|-|-|-|
| `literal_type` | `types_pb2.LiteralType` | |
### isinstance_generic()
```python
def isinstance_generic(
obj,
generic_alias,
)
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | | |
| `generic_alias` | | |
### iter_as()
```python
def iter_as(
sd: literals_pb2.StructuredDataset,
df_type: Type[DF],
updated_metadata: literals_pb2.StructuredDatasetMetadata,
) -> typing.AsyncIterator[DF]
```
| Parameter | Type | Description |
|-|-|-|
| `sd` | `literals_pb2.StructuredDataset` | |
| `df_type` | `Type[DF]` | |
| `updated_metadata` | `literals_pb2.StructuredDatasetMetadata` | |
### open_as()
```python
def open_as(
sd: literals_pb2.StructuredDataset,
df_type: Type[DF],
updated_metadata: literals_pb2.StructuredDatasetMetadata,
) -> DF
```
| Parameter | Type | Description |
|-|-|-|
| `sd` | `literals_pb2.StructuredDataset` | |
| `df_type` | `Type[DF]` | |
| `updated_metadata` | `literals_pb2.StructuredDatasetMetadata` | New metadata type, since it might be different from the metadata in the literal. |
**Returns:** A dataframe instance, e.g. a pandas DataFrame or an Arrow Table.
### register()
```python
def register(
h: Handlers,
default_for_type: bool,
override: bool,
default_format_for_type: bool,
default_storage_for_type: bool,
)
```
Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
specify a protocol (e.g. s3, gs, etc.), then it will be registered for all protocols that flytekit's data
persistence layer is capable of handling.
| Parameter | Type | Description |
|-|-|-|
| `h` | `Handlers` | The DataFrameEncoder or DataFrameDecoder you wish to register with this transformer. |
| `default_for_type` | `bool` | If set, when a user returns from a task an instance of the dataframe the handler handles, e.g. `return pd.DataFrame(...)`, not wrapped around the `StructuredDataset` object, we will use this handler's protocol and format as the default, effectively saying that this handler will be called. Note that this shouldn't be set if your handler's protocol is None, because that implies that your handler is capable of handling all the different storage protocols that flytekit's data persistence layer is aware of. In these cases, the protocol is determined by the raw output data prefix set in the active context. |
| `override` | `bool` | Override any previous registrations. If default_for_type is also set, this will also override the default. |
| `default_format_for_type` | `bool` | Unlike the default_for_type arg that will set this handler's format and storage as the default, this will only set the format. Error if already set, unless override is specified. |
| `default_storage_for_type` | `bool` | Same as above but only for the storage format. Error if already set, unless override is specified. |
### register_for_protocol()
```python
def register_for_protocol(
h: Handlers,
protocol: str,
default_for_type: bool,
override: bool,
default_format_for_type: bool,
default_storage_for_type: bool,
)
```
See the main register function instead.
| Parameter | Type | Description |
|-|-|-|
| `h` | `Handlers` | |
| `protocol` | `str` | |
| `default_for_type` | `bool` | |
| `override` | `bool` | |
| `default_format_for_type` | `bool` | |
| `default_storage_for_type` | `bool` | |
### register_renderer()
```python
def register_renderer(
python_type: Type,
renderer: Renderable,
)
```
| Parameter | Type | Description |
|-|-|-|
| `python_type` | `Type` | |
| `renderer` | `Renderable` | |
### schema_match()
```python
def schema_match(
schema: dict,
) -> bool
```
Check if a JSON schema fragment matches this transformer's python_type.
For BaseModel subclasses, automatically compares the schema's title, type, and
required fields against the type's own JSON schema. For other types, returns
False by default; override if needed.
| Parameter | Type | Description |
|-|-|-|
| `schema` | `dict` | |
### to_html()
```python
def to_html(
python_val: typing.Any,
expected_python_type: Type[T],
) -> str
```
Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `typing.Any` | |
| `expected_python_type` | `Type[T]` | |
### to_literal()
```python
def to_literal(
python_val: Union[DataFrame, typing.Any],
python_type: Union[Type[DataFrame], Type],
expected: types_pb2.LiteralType,
) -> literals_pb2.Literal
```
Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type.
Implementers should refrain from using type(python_val) instead rely on the passed in python_type. If these
do not match (or are not allowed) the Transformer implementer should raise an AssertionError, clearly stating
what was the mismatch
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `Union[DataFrame, typing.Any]` | The actual value to be transformed |
| `python_type` | `Union[Type[DataFrame], Type]` | The assumed type of the value (this matches the declared type on the function) |
| `expected` | `types_pb2.LiteralType` | Expected Literal Type |
### to_python_value()
```python
def to_python_value(
lv: literals_pb2.Literal,
expected_python_type: Type[T] | DataFrame,
) -> T | DataFrame
```
The only tricky thing with converting a Literal (say the output of an earlier task) to a Python value at
the start of a task execution is the column subsetting behavior. For example, if you have:
```python
def t1() -> Annotated[StructuredDataset, kwtypes(col_a=int, col_b=float)]: ...
def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
```
where `t2(in_a=t1())`, when t2 does `in_a.open(pd.DataFrame).all()`, it should get a DataFrame
with only one column.
```
+-----------------------------+-----------------------------------------+--------------------------------------+
|                             |          StructuredDatasetType of the incoming Literal                         |
+-----------------------------+-----------------------------------------+--------------------------------------+
| StructuredDatasetType       | Has columns defined                     | [] columns or None                   |
| of currently running task   |                                         |                                      |
+=============================+=========================================+======================================+
| Has columns                 | The StructuredDatasetType passed to the decoder will have the columns          |
| defined                     | as defined by the type annotation of the currently running task.               |
|                             |                                                                                |
|                             | Decoders **should** then subset the incoming data to the columns requested.    |
+-----------------------------+-----------------------------------------+--------------------------------------+
| [] columns or None          | StructuredDatasetType passed to decoder | StructuredDatasetType passed to the  |
|                             | will have the columns from the incoming | decoder will have an empty list of   |
|                             | Literal. This is the scenario where     | columns.                             |
|                             | the Literal returned by the running     |                                      |
|                             | task will have more information than    |                                      |
|                             | the running task's signature.           |                                      |
+-----------------------------+-----------------------------------------+--------------------------------------+
```
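The decision table above can be summarized as a small helper. This is an illustrative encoding of the rules only (a hypothetical function, not part of the SDK):

```python
def columns_for_decoder(task_columns, literal_columns):
    # If the currently running task's annotation declares columns, those win
    # and the decoder should subset the incoming data to them; otherwise the
    # incoming Literal's columns (possibly empty) are passed through.
    if task_columns:
        return list(task_columns)
    return list(literal_columns or [])

# t2's annotation declares only col_b, so the decoder sees just col_b:
assert columns_for_decoder(["col_b"], ["col_a", "col_b"]) == ["col_b"]
# Task declares no columns: the Literal's richer type flows through:
assert columns_for_decoder(None, ["col_a", "col_b"]) == ["col_a", "col_b"]
# Neither side declares columns: decoder receives an empty list:
assert columns_for_decoder(None, None) == []
```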
| Parameter | Type | Description |
|-|-|-|
| `lv` | `literals_pb2.Literal` | |
| `expected_python_type` | `Type[T] \| DataFrame` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models ===
# flyte.models
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > ActionID** | A class representing the ID of an Action, nested within a Run. |
| **Flyte SDK > Packages > flyte.models > ActionPhase** | Represents the execution phase of a Flyte action (run). |
| **Flyte SDK > Packages > flyte.models > Checkpoints** | A class representing the checkpoints for a task. |
| **Flyte SDK > Packages > flyte.models > CodeBundle** | A class representing a code bundle for a task. |
| **Flyte SDK > Packages > flyte.models > GroupData** | |
| **Flyte SDK > Packages > flyte.models > NativeInterface** | A class representing the native interface for a task. |
| **Flyte SDK > Packages > flyte.models > PathRewrite** | Configuration for rewriting paths during input loading. |
| **Flyte SDK > Packages > flyte.models > RawDataPath** | A class representing the raw data path for a task. |
| **Flyte SDK > Packages > flyte.models > SerializationContext** | This object holds serialization time contextual information, that can be used when serializing the task and. |
| **Flyte SDK > Packages > flyte.models > TaskContext** | A context class to hold the current task executions context. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > Methods > generate_random_name()** | Generate a random name for the task. |
### Variables
| Property | Type | Description |
|-|-|-|
| `MAX_INLINE_IO_BYTES` | `int` | |
| `TYPE_CHECKING` | `bool` | |
## Methods
#### generate_random_name()
```python
def generate_random_name()
```
Generate a random name for the task. This is used to create unique names for tasks.
Currently, names are simple GUIDs; a human-friendly name generator may be used in the future.
## Subpages
- **Flyte SDK > Packages > flyte.models > ActionID**
- **Flyte SDK > Packages > flyte.models > ActionPhase**
- **Flyte SDK > Packages > flyte.models > Checkpoints**
- **Flyte SDK > Packages > flyte.models > CodeBundle**
- **Flyte SDK > Packages > flyte.models > GroupData**
- **Flyte SDK > Packages > flyte.models > NativeInterface**
- **Flyte SDK > Packages > flyte.models > PathRewrite**
- **Flyte SDK > Packages > flyte.models > RawDataPath**
- **Flyte SDK > Packages > flyte.models > SerializationContext**
- **Flyte SDK > Packages > flyte.models > TaskContext**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/actionid ===
# ActionID
**Package:** `flyte.models`
A class representing the ID of an Action, nested within a Run. This is used to identify a specific action on a task.
## Parameters
```python
class ActionID(
name: str,
run_name: str | None,
project: str | None,
domain: str | None,
org: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `run_name` | `str \| None` | |
| `project` | `str \| None` | |
| `domain` | `str \| None` | |
| `org` | `str \| None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > ActionID > Methods > create_random()** | |
| **Flyte SDK > Packages > flyte.models > ActionID > Methods > new_sub_action()** | Create a new sub-run with the given name. |
| **Flyte SDK > Packages > flyte.models > ActionID > Methods > new_sub_action_from()** | Make a deterministic name. |
| **Flyte SDK > Packages > flyte.models > ActionID > Methods > unique_id_str()** | Generate a unique ID string for this action in the format:. |
### create_random()
```python
def create_random()
```
### new_sub_action()
```python
def new_sub_action(
name: str | None,
) -> ActionID
```
Create a new sub-run with the given name. If name is None, a random name will be generated.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str \| None` | |
### new_sub_action_from()
```python
def new_sub_action_from(
task_call_seq: int,
task_hash: str,
input_hash: str,
group: str | None,
) -> ActionID
```
Make a deterministic name.
| Parameter | Type | Description |
|-|-|-|
| `task_call_seq` | `int` | |
| `task_hash` | `str` | |
| `input_hash` | `str` | |
| `group` | `str \| None` | |
### unique_id_str()
```python
def unique_id_str(
salt: str | None,
) -> str
```
Generate a unique ID string for this action in the format
`{project}-{domain}-{run_name}-{action_name}`.
This is optimized for performance assuming all fields are available.
| Parameter | Type | Description |
|-|-|-|
| `salt` | `str \| None` | |
**Returns:** A unique ID string
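A minimal sketch of the documented format, for illustration. How the optional `salt` is incorporated is an assumption (appended here); the SDK's actual salting behavior may differ:

```python
def unique_id_str(project, domain, run_name, action_name, salt=None):
    # Builds "{project}-{domain}-{run_name}-{action_name}" as documented.
    # Salt handling below is a hypothetical choice for this sketch.
    base = f"{project}-{domain}-{run_name}-{action_name}"
    return f"{base}-{salt}" if salt else base

assert unique_id_str("proj", "dev", "run1", "a0") == "proj-dev-run1-a0"
```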
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/actionphase ===
# ActionPhase
**Package:** `flyte.models`
Represents the execution phase of a Flyte action (run).
Actions progress through different phases during their lifecycle:
- Queued: Action is waiting to be scheduled
- Waiting for resources: Action is waiting for compute resources
- Initializing: Action is being initialized
- Running: Action is currently executing
- Succeeded: Action completed successfully
- Failed: Action failed during execution
- Aborted: Action was manually aborted
- Timed out: Action exceeded its timeout limit
This enum can be used for filtering runs and checking execution status.
Example:
>>> from flyte.models import ActionPhase
>>> from flyte.remote import Run
>>>
>>> # Filter runs by phase
>>> runs = Run.listall(in_phase=(ActionPhase.SUCCEEDED, ActionPhase.FAILED))
>>>
>>> # Check if a run succeeded
>>> run = Run.get("my-run")
>>> if run.phase == ActionPhase.SUCCEEDED:
... print("Success!")
>>>
>>> # Check if phase is terminal
>>> if run.phase.is_terminal:
... print("Run completed")
## Parameters
```python
class ActionPhase(
args,
kwds,
)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwds` | | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/checkpoints ===
# Checkpoints
**Package:** `flyte.models`
A class representing the checkpoints for a task. This is used to store the checkpoints for the task execution.
## Parameters
```python
class Checkpoints(
prev_checkpoint_path: str | None,
checkpoint_path: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `prev_checkpoint_path` | `str \| None` | |
| `checkpoint_path` | `str \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/codebundle ===
# CodeBundle
**Package:** `flyte.models`
A class representing a code bundle for a task. This is used to package the code and the inflation path.
The code bundle computes the version of the code using the hash of the code.
## Parameters
```python
class CodeBundle(
computed_version: str,
destination: str,
tgz: str | None,
pkl: str | None,
downloaded_path: pathlib.Path | None,
files: List[str] | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `computed_version` | `str` | The version of the code bundle. This is the hash of the code. |
| `destination` | `str` | The destination path for the code bundle to be inflated to. |
| `tgz` | `str \| None` | Optional path to the tgz file. |
| `pkl` | `str \| None` | Optional path to the pkl file. |
| `downloaded_path` | `pathlib.Path \| None` | The path to the downloaded code bundle. This is only available during runtime, when the code bundle has been downloaded and inflated. |
| `files` | `List[str] \| None` | |
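Since `computed_version` is described as the hash of the code, versions change exactly when file contents change. A minimal sketch of content-hash versioning, assuming a SHA-256 digest over sorted file names and contents (the SDK's actual hashing scheme is an implementation detail):

```python
import hashlib

# Hypothetical content-hash versioning, analogous to CodeBundle.computed_version.
code_files = {"main.py": b"def main(): ...", "utils.py": b"X = 1"}
h = hashlib.sha256()
for name in sorted(code_files):  # stable ordering makes the hash deterministic
    h.update(name.encode())
    h.update(code_files[name])
version = h.hexdigest()[:12]
```

Any edit to a file's bytes produces a different `version`, while re-hashing unchanged files reproduces it exactly.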
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > CodeBundle > Methods > with_downloaded_path()** | Create a new CodeBundle with the given downloaded path. |
### with_downloaded_path()
```python
def with_downloaded_path(
path: pathlib.Path,
) -> CodeBundle
```
Create a new CodeBundle with the given downloaded path.
| Parameter | Type | Description |
|-|-|-|
| `path` | `pathlib.Path` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/groupdata ===
# GroupData
**Package:** `flyte.models`
## Parameters
```python
class GroupData(
name: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/nativeinterface ===
# NativeInterface
**Package:** `flyte.models`
A class representing the native interface for a task. This is used to interact with the task and its execution
context.
## Parameters
```python
class NativeInterface(
inputs: Dict[str, Tuple[Type, Any]],
outputs: Dict[str, Type],
docstring: Optional[Docstring],
_remote_defaults: Optional[Dict[str, literals_pb2.Literal]],
)
```
| Parameter | Type | Description |
|-|-|-|
| `inputs` | `Dict[str, Tuple[Type, Any]]` | |
| `outputs` | `Dict[str, Type]` | |
| `docstring` | `Optional[Docstring]` | |
| `_remote_defaults` | `Optional[Dict[str, literals_pb2.Literal]]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `json_schema` | `None` | Convert task inputs to a JSON schema dict. Uses the Flyte type engine to produce a LiteralType for each input, then converts to JSON schema. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > NativeInterface > Methods > convert_to_kwargs()** | Convert the given arguments to keyword arguments based on the native interface. |
| **Flyte SDK > Packages > flyte.models > NativeInterface > Methods > from_callable()** | Extract the native interface from the given function. |
| **Flyte SDK > Packages > flyte.models > NativeInterface > Methods > from_types()** | Create a new NativeInterface from the given types. |
| **Flyte SDK > Packages > flyte.models > NativeInterface > Methods > get_input_types()** | Get the input types for the task. |
| **Flyte SDK > Packages > flyte.models > NativeInterface > Methods > has_outputs()** | Check if the task has outputs. |
| **Flyte SDK > Packages > flyte.models > NativeInterface > Methods > num_required_inputs()** | Get the number of required inputs for the task. |
| **Flyte SDK > Packages > flyte.models > NativeInterface > Methods > required_inputs()** | Get the names of the required inputs for the task. |
### convert_to_kwargs()
```python
def convert_to_kwargs(
args,
kwargs,
) -> Dict[str, Any]
```
Convert the given arguments to keyword arguments based on the native interface. This is used to convert the
arguments to the correct types for the task execution.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
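Normalizing a mix of positional and keyword arguments against a function's parameter names is the general idea here. The stdlib `inspect` module demonstrates the binding mechanism (this is an illustration of the concept, not the SDK's implementation; `my_task` is a made-up function):

```python
import inspect

def my_task(name: str, count: int = 1) -> str:
    return name * count

# Bind positional + keyword args to parameter names, as a
# NativeInterface-style kwargs conversion might.
bound = inspect.signature(my_task).bind("ab", count=2)
bound.apply_defaults()
kwargs = dict(bound.arguments)
print(kwargs)  # {'name': 'ab', 'count': 2}
```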
### from_callable()
```python
def from_callable(
func: Callable,
) -> NativeInterface
```
Extract the native interface from the given function. This is used to create a native interface for the task.
| Parameter | Type | Description |
|-|-|-|
| `func` | `Callable` | |
### from_types()
```python
def from_types(
inputs: Dict[str, Tuple[Type, Type[_has_default] | Type[inspect._empty]]],
outputs: Dict[str, Type],
default_inputs: Optional[Dict[str, literals_pb2.Literal]],
) -> NativeInterface
```
Create a new NativeInterface from the given types. This is used to create a native interface for the task.
| Parameter | Type | Description |
|-|-|-|
| `inputs` | `Dict[str, Tuple[Type, Type[_has_default] \| Type[inspect._empty]]]` | A dictionary of input names and their types and a value indicating if they have a default value. |
| `outputs` | `Dict[str, Type]` | A dictionary of output names and their types. |
| `default_inputs` | `Optional[Dict[str, literals_pb2.Literal]]` | Optional dictionary of default inputs for remote tasks. |
**Returns:** A NativeInterface object with the given inputs and outputs.
### get_input_types()
```python
def get_input_types()
```
Get the input types for the task. This is used to get the types of the inputs for the task execution.
### has_outputs()
```python
def has_outputs()
```
Check if the task has outputs. This is used to determine if the task has outputs or not.
### num_required_inputs()
```python
def num_required_inputs()
```
Get the number of required inputs for the task. This is used to determine how many inputs are required for the
task execution.
### required_inputs()
```python
def required_inputs()
```
Get the names of the required inputs for the task. This is used to determine which inputs are required for the
task execution.
**Returns:** A list of required input names.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/pathrewrite ===
# PathRewrite
**Package:** `flyte.models`
Configuration for rewriting paths during input loading.
## Parameters
```python
class PathRewrite(
old_prefix: str,
new_prefix: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `old_prefix` | `str` | |
| `new_prefix` | `str` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > PathRewrite > Methods > from_str()** | Create a PathRewrite from a string pattern of the form `old_prefix->new_prefix`. |
### from_str()
```python
def from_str(
pattern: str,
) -> PathRewrite
```
Create a PathRewrite from a string pattern of the form `old_prefix->new_prefix`.
| Parameter | Type | Description |
|-|-|-|
| `pattern` | `str` | |
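The `old_prefix->new_prefix` pattern encodes a prefix substitution. A sketch of how such a pattern decomposes and applies (the bucket names are made up, and this mirrors only the documented string form, not the SDK's internals):

```python
# Split the documented `old_prefix->new_prefix` pattern.
pattern = "s3://old-bucket/data->s3://new-bucket/data"
old_prefix, new_prefix = pattern.split("->", 1)

# Rewriting a path is then a prefix substitution:
path = "s3://old-bucket/data/inputs/file.parquet"
rewritten = new_prefix + path[len(old_prefix):]
print(rewritten)  # s3://new-bucket/data/inputs/file.parquet
```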
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/rawdatapath ===
# RawDataPath
**Package:** `flyte.models`
A class representing the raw data path for a task. This is used to store the raw data for the task execution and
to derive new paths from it.
## Parameters
```python
class RawDataPath(
path: str,
path_rewrite: Optional[PathRewrite],
)
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `path_rewrite` | `Optional[PathRewrite]` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > RawDataPath > Methods > from_local_folder()** | Create a new context attribute object, with local path given. |
| **Flyte SDK > Packages > flyte.models > RawDataPath > Methods > get_random_remote_path()** | Returns a random path for uploading a file/directory to. |
### from_local_folder()
```python
def from_local_folder(
local_folder: str | pathlib.Path | None,
) -> RawDataPath
```
Create a new context attribute object with the given local path. The path will be created if it doesn't exist.
| Parameter | Type | Description |
|-|-|-|
| `local_folder` | `str \| pathlib.Path \| None` | |
**Returns:** Path to the temporary directory
### get_random_remote_path()
```python
def get_random_remote_path(
file_name: Optional[str],
) -> str
```
Returns a random path for uploading a file/directory to. This file/folder will not be created, it's just a path.
| Parameter | Type | Description |
|-|-|-|
| `file_name` | `Optional[str]` | If given, will be joined after a randomly generated portion. |
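Per the parameter description, the optional file name is joined after a randomly generated segment under the raw data path. A sketch of that layout (bucket and file names below are hypothetical, and the random segment here uses `uuid` purely for illustration):

```python
import uuid

# Random remote path: <base>/<random>/<file_name>; nothing is created on disk.
base = "s3://my-bucket/raw-data"
random_part = uuid.uuid4().hex
remote_path = f"{base}/{random_part}/model.pt"
```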
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/serializationcontext ===
# SerializationContext
**Package:** `flyte.models`
This object holds serialization-time contextual information that can be used when serializing the task and
various parameters of a task template. It is only available while the task is being serialized, which can
happen during deployment or at runtime.
## Parameters
```python
class SerializationContext(
version: str,
project: str | None,
domain: str | None,
org: str | None,
code_bundle: Optional[CodeBundle],
input_path: str,
output_path: str,
interpreter_path: str,
image_cache: ImageCache | None,
root_dir: Optional[pathlib.Path],
)
```
| Parameter | Type | Description |
|-|-|-|
| `version` | `str` | The version of the task |
| `project` | `str \| None` | |
| `domain` | `str \| None` | |
| `org` | `str \| None` | |
| `code_bundle` | `Optional[CodeBundle]` | The code bundle for the task. This is used to package the code and the inflation path. |
| `input_path` | `str` | The path to the inputs for the task. This is used to determine where the inputs will be located |
| `output_path` | `str` | The path to the outputs for the task. This is used to determine where the outputs will be located |
| `interpreter_path` | `str` | |
| `image_cache` | `ImageCache \| None` | |
| `root_dir` | `Optional[pathlib.Path]` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > SerializationContext > Methods > get_entrypoint_path()** | Get the entrypoint path for the task. |
### get_entrypoint_path()
```python
def get_entrypoint_path(
interpreter_path: Optional[str],
) -> str
```
Get the entrypoint path for the task. This is used to determine the entrypoint for the task execution.
| Parameter | Type | Description |
|-|-|-|
| `interpreter_path` | `Optional[str]` | The path to the interpreter (python) |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/taskcontext ===
# TaskContext
**Package:** `flyte.models`
A context class that holds the current task execution's context.
It can be used to access various contextual parameters of the task execution.
## Parameters
```python
class TaskContext(
action: ActionID,
version: str,
raw_data_path: RawDataPath,
input_path: str | None,
output_path: str,
run_base_dir: str,
report: Report,
group_data: GroupData | None,
checkpoints: Checkpoints | None,
code_bundle: CodeBundle | None,
compiled_image_cache: ImageCache | None,
data: Dict[str, Any],
mode: Literal['local', 'remote', 'hybrid'],
interactive_mode: bool,
custom_context: Dict[str, str],
disable_run_cache: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `action` | `ActionID` | The action ID of the current execution. This is always set, within a run. |
| `version` | `str` | The version of the executed task. This is set when the task is executed by an action and will be set on all sub-actions. |
| `raw_data_path` | `RawDataPath` | |
| `input_path` | `str \| None` | |
| `output_path` | `str` | |
| `run_base_dir` | `str` | |
| `report` | `Report` | |
| `group_data` | `GroupData \| None` | |
| `checkpoints` | `Checkpoints \| None` | |
| `code_bundle` | `CodeBundle \| None` | |
| `compiled_image_cache` | `ImageCache \| None` | |
| `data` | `Dict[str, Any]` | |
| `mode` | `Literal['local', 'remote', 'hybrid']` | |
| `interactive_mode` | `bool` | |
| `custom_context` | `Dict[str, str]` | Context metadata for the action. If an action receives context, it'll automatically pass it to any actions it spawns. Context will not be used for cache key computation. |
| `disable_run_cache` | `bool` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > TaskContext > Methods > is_in_cluster()** | Check if the task is running in a cluster. |
| **Flyte SDK > Packages > flyte.models > TaskContext > Methods > replace()** | |
### is_in_cluster()
```python
def is_in_cluster()
```
Check if the task is running in a cluster.
**Returns:** bool
### replace()
```python
def replace(
kwargs,
) -> TaskContext
```
| Parameter | Type | Description |
|-|-|-|
| `kwargs` | `**kwargs` | |
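`replace()` returns a copy of the context with the given fields overridden, following the familiar `dataclasses.replace` pattern. A minimal stand-in (the `MiniContext` class below is hypothetical, not the real `TaskContext`):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class MiniContext:  # hypothetical stand-in for TaskContext
    version: str
    mode: str

ctx = MiniContext(version="v1", mode="local")
remote_ctx = replace(ctx, mode="remote")  # original ctx is unchanged
print(remote_ctx.mode)  # remote
```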
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch ===
# flyte.prefetch
Prefetch utilities for Flyte.
This module provides functionality to prefetch various artifacts from remote registries,
such as HuggingFace models.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo** | Information about a HuggingFace model to store. |
| **Flyte SDK > Packages > flyte.prefetch > ShardConfig** | Configuration for model sharding. |
| **Flyte SDK > Packages > flyte.prefetch > StoredModelInfo** | Information about a stored model. |
| **Flyte SDK > Packages > flyte.prefetch > VLLMShardArgs** | Arguments for sharding a model using vLLM. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.prefetch > Methods > hf_model()** | Store a HuggingFace model to remote storage. |
## Methods
#### hf_model()
```python
def hf_model(
repo: str,
raw_data_path: str | None,
artifact_name: str | None,
architecture: str | None,
task: str,
modality: tuple[str, ...],
serial_format: str | None,
model_type: str | None,
short_description: str | None,
shard_config: ShardConfig | None,
hf_token_key: str,
resources: Resources,
force: int,
) -> Run
```
Store a HuggingFace model to remote storage.
This function downloads a model from the HuggingFace Hub and prefetches it to
remote storage. It supports optional sharding using vLLM for large models.
The prefetch behavior follows this priority:
1. If the model isn't being sharded, stream files directly to remote storage.
2. If streaming fails, fall back to downloading a snapshot and uploading.
3. If sharding is configured, download locally, shard with vLLM, then upload.
Example usage:
```python
import flyte
flyte.init(endpoint="my-flyte-endpoint")
# Store a model without sharding
run = flyte.prefetch.hf_model(
repo="meta-llama/Llama-2-7b-hf",
hf_token_key="HF_TOKEN",
)
run.wait()
# Prefetch and shard a model
from flyte.prefetch import ShardConfig, VLLMShardArgs
run = flyte.prefetch.hf_model(
repo="meta-llama/Llama-2-70b-hf",
shard_config=ShardConfig(
engine="vllm",
args=VLLMShardArgs(tensor_parallel_size=8),
),
accelerator="A100:8",
hf_token_key="HF_TOKEN",
)
run.wait()
```
| Parameter | Type | Description |
|-|-|-|
| `repo` | `str` | The HuggingFace repository ID (e.g., 'meta-llama/Llama-2-7b-hf'). |
| `raw_data_path` | `str \| None` | |
| `artifact_name` | `str \| None` | Optional name for the stored artifact. If not provided, the repo name will be used (with '.' replaced by '-'). |
| `architecture` | `str \| None` | Model architecture from HuggingFace config.json. |
| `task` | `str` | Model task (e.g., 'generate', 'classify', 'embed'). Default |
| `modality` | `tuple[str, ...]` | Modalities supported by the model. Default |
| `serial_format` | `str \| None` | Model serialization format (e.g., 'safetensors', 'onnx'). |
| `model_type` | `str \| None` | Model type (e.g., 'transformer', 'custom'). |
| `short_description` | `str \| None` | Short description of the model. |
| `shard_config` | `ShardConfig \| None` | Optional configuration for model sharding with vLLM. |
| `hf_token_key` | `str` | Name of the secret containing the HuggingFace token. Default |
| `resources` | `Resources` | |
| `force` | `int` | Force re-prefetch. Increment to force a new prefetch. Default |
**Returns:** A Run object representing the prefetch task execution.
## Subpages
- **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo**
- **Flyte SDK > Packages > flyte.prefetch > ShardConfig**
- **Flyte SDK > Packages > flyte.prefetch > StoredModelInfo**
- **Flyte SDK > Packages > flyte.prefetch > VLLMShardArgs**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/huggingfacemodelinfo ===
# HuggingFaceModelInfo
**Package:** `flyte.prefetch`
Information about a HuggingFace model to store.
## Parameters
```python
class HuggingFaceModelInfo(
repo: str,
artifact_name: str | None,
architecture: str | None,
task: str,
modality: tuple[str, ...],
serial_format: str | None,
model_type: str | None,
short_description: str | None,
shard_config: flyte.prefetch._hf_model.ShardConfig | None,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `repo` | `str` | The HuggingFace repository ID (e.g., 'meta-llama/Llama-2-7b-hf'). |
| `artifact_name` | `str \| None` | Optional name for the stored artifact. If not provided, the repo name will be used (with '.' replaced by '-'). |
| `architecture` | `str \| None` | Model architecture from HuggingFace config.json. |
| `task` | `str` | Model task (e.g., 'generate', 'classify', 'embed'). |
| `modality` | `tuple[str, ...]` | Modalities supported by the model (e.g., 'text', 'image'). |
| `serial_format` | `str \| None` | Model serialization format (e.g., 'safetensors', 'onnx'). |
| `model_type` | `str \| None` | Model type (e.g., 'transformer', 'custom'). |
| `short_description` | `str \| None` | Short description of the model. |
| `shard_config` | `flyte.prefetch._hf_model.ShardConfig \| None` | Optional configuration for model sharding. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/shardconfig ===
# ShardConfig
**Package:** `flyte.prefetch`
Configuration for model sharding.
## Parameters
```python
class ShardConfig(
engine: typing.Literal['vllm'],
args: *args,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `engine` | `typing.Literal['vllm']` | The sharding engine to use (currently only "vllm" is supported). |
| `args` | `*args` | Arguments for the sharding engine. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/storedmodelinfo ===
# StoredModelInfo
**Package:** `flyte.prefetch`
Information about a stored model.
## Parameters
```python
class StoredModelInfo(
artifact_name: str,
path: str,
metadata: dict[str, str],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `artifact_name` | `str` | Name of the stored artifact. |
| `path` | `str` | Path to the stored model directory. |
| `metadata` | `dict[str, str]` | Metadata about the stored model. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/vllmshardargs ===
# VLLMShardArgs
**Package:** `flyte.prefetch`
Arguments for sharding a model using vLLM.
## Parameters
```python
class VLLMShardArgs(
tensor_parallel_size: int,
dtype: str,
trust_remote_code: bool,
max_model_len: int | None,
file_pattern: str | None,
max_file_size: int,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `tensor_parallel_size` | `int` | Number of tensor parallel workers. |
| `dtype` | `str` | Data type for model weights. |
| `trust_remote_code` | `bool` | Whether to trust remote code from HuggingFace. |
| `max_model_len` | `int \| None` | Maximum model context length. |
| `file_pattern` | `str \| None` | Pattern for sharded weight files. |
| `max_file_size` | `int` | Maximum size for each sharded file. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.prefetch > VLLMShardArgs > Methods > get_vllm_args()** | Get arguments dict for vLLM LLM constructor. |
### get_vllm_args()
```python
def get_vllm_args(
model_path: str,
) -> dict[str, Any]
```
Get arguments dict for vLLM LLM constructor.
| Parameter | Type | Description |
|-|-|-|
| `model_path` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote ===
# flyte.remote
Remote Entities that are accessible from the Union Server once deployed or created.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Action** | A class representing an action. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails** | A class representing an action. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs** | A class representing the inputs of an action. |
| **Flyte SDK > Packages > flyte.remote > ActionOutputs** | A class representing the outputs of an action. |
| **Flyte SDK > Packages > flyte.remote > App** | |
| **Flyte SDK > Packages > flyte.remote > Project** | A class representing a project in the Union API. |
| **Flyte SDK > Packages > flyte.remote > Run** | A class representing a run of a task. |
| **Flyte SDK > Packages > flyte.remote > RunDetails** | A class representing a run of a task. |
| **Flyte SDK > Packages > flyte.remote > Secret** | |
| **Flyte SDK > Packages > flyte.remote > Task** | |
| **Flyte SDK > Packages > flyte.remote > TaskDetails** | |
| **Flyte SDK > Packages > flyte.remote > TimeFilter** | Filter for time-based fields. |
| **Flyte SDK > Packages > flyte.remote > Trigger** | Represents a trigger in the Flyte platform. |
| **Flyte SDK > Packages > flyte.remote > User** | Represents a user in the Flyte platform. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Methods > auth_metadata()** | This context manager allows you to pass contextualized auth metadata downstream to the Flyte authentication system. |
| **Flyte SDK > Packages > flyte.remote > Methods > create_channel()** | Creates a new gRPC channel with appropriate authentication interceptors. |
| **Flyte SDK > Packages > flyte.remote > Methods > upload_dir()** | Uploads a directory to a remote location and returns the remote URI. |
| **Flyte SDK > Packages > flyte.remote > Methods > upload_file()** | Uploads a file to a remote location and returns the remote URI. |
## Methods
#### auth_metadata()
```python
def auth_metadata(
kv: typing.Tuple[str, str],
)
```
This context manager allows you to pass contextualized auth metadata downstream to the Flyte authentication system.
This is only useful if flyte.init_passthrough() has been called.
Example:
```python
flyte.init_passthrough("my-endpoint")
...
with auth_metadata((key1, value1), (key2, value2)):
...
```
| Parameter | Type | Description |
|-|-|-|
| `kv` | `typing.Tuple[str, str]` | |
#### create_channel()
```python
def create_channel(
endpoint: str | None,
api_key: str | None,
insecure: typing.Optional[bool],
insecure_skip_verify: typing.Optional[bool],
ca_cert_file_path: typing.Optional[str],
ssl_credentials: typing.Optional[ssl_channel_credentials],
grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]],
compression: typing.Optional[grpc.Compression],
http_session: httpx.AsyncClient | None,
proxy_command: typing.Optional[typing.List[str]],
rpc_retries: typing.Optional[int],
kwargs,
) -> grpc.aio._base_channel.Channel
```
Creates a new gRPC channel with appropriate authentication interceptors.
This function creates either a secure or insecure gRPC channel based on the provided parameters,
and adds authentication interceptors to the channel. If SSL credentials are not provided,
they are created based on the insecure_skip_verify and ca_cert_file_path parameters.
The function is async because it may need to read certificate files asynchronously
and create authentication interceptors that perform async operations.
| Parameter | Type | Description |
|-|-|-|
| `endpoint` | `str \| None` | The endpoint URL for the gRPC channel |
| `api_key` | `str \| None` | API key for authentication; if provided, it will be used to detect the endpoint and credentials. |
| `insecure` | `typing.Optional[bool]` | Whether to use an insecure channel (no SSL) |
| `insecure_skip_verify` | `typing.Optional[bool]` | Whether to skip SSL certificate verification |
| `ca_cert_file_path` | `typing.Optional[str]` | Path to CA certificate file for SSL verification |
| `ssl_credentials` | `typing.Optional[ssl_channel_credentials]` | Pre-configured SSL credentials for the channel |
| `grpc_options` | `typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]]` | Additional gRPC channel options |
| `compression` | `typing.Optional[grpc.Compression]` | Compression method for the channel |
| `http_session` | `httpx.AsyncClient \| None` | Pre-configured HTTP session to use for requests |
| `proxy_command` | `typing.Optional[typing.List[str]]` | List of strings for proxy command configuration |
| `rpc_retries` | `typing.Optional[int]` | Number of times to retry gRPC calls (flyte.init defaults to 3). None means do not install the interceptor at all. |
| `kwargs` | `**kwargs` | Additional arguments passed to various functions - For grpc.aio.insecure_channel/secure_channel: - root_certificates: Root certificates for SSL credentials - private_key: Private key for SSL credentials - certificate_chain: Certificate chain for SSL credentials - options: gRPC channel options - compression: gRPC compression method - For proxy configuration: - proxy_env: Dict of environment variables for proxy - proxy_timeout: Timeout for proxy connection - For authentication interceptors (passed to create_auth_interceptors and create_proxy_auth_interceptors): - auth_type: The authentication type to use ("Pkce", "ClientSecret", "ExternalCommand", "DeviceFlow") - command: Command to execute for ExternalCommand authentication - client_id: Client ID for ClientSecret authentication - client_secret: Client secret for ClientSecret authentication - client_credentials_secret: Client secret for ClientSecret authentication (alias) - scopes: List of scopes to request during authentication - audience: Audience for the token - http_proxy_url: HTTP proxy URL - verify: Whether to verify SSL certificates - ca_cert_path: Optional path to CA certificate file - header_key: Header key to use for authentication - redirect_uri: OAuth2 redirect URI for PKCE authentication - add_request_auth_code_params_to_request_access_token_params: Whether to add auth code params to token request - request_auth_code_params: Parameters to add to login URI opened in browser - request_access_token_params: Parameters to add when exchanging auth code for access token - refresh_access_token_params: Parameters to add when refreshing access token |
**Returns:** grpc.aio.Channel with authentication interceptors configured
#### upload_dir()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.,:
> `result = await upload_dir.aio()`.
```python
def upload_dir(
dir_path: pathlib.Path,
verify: bool,
prefix: str | None,
) -> str
```
Uploads a directory to a remote location and returns the remote URI.
| Parameter | Type | Description |
|-|-|-|
| `dir_path` | `pathlib.Path` | The directory path to upload. |
| `verify` | `bool` | Whether to verify the certificate for HTTPS requests. |
| `prefix` | `str \| None` | |
**Returns:** The remote URI of the uploaded directory.
#### upload_file()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.,:
> `result = await upload_file.aio()`.
```python
def upload_file(
fp: pathlib.Path,
verify: bool,
fname: str | None,
) -> typing.Tuple[str, str]
```
Uploads a file to a remote location and returns the remote URI.
| Parameter | Type | Description |
|-|-|-|
| `fp` | `pathlib.Path` | The file path to upload. |
| `verify` | `bool` | Whether to verify the certificate for HTTPS requests. |
| `fname` | `str \| None` | Optional file name for the remote path. |
**Returns:** Tuple of (MD5 digest hex string, remote native URL).
## Subpages
- **Flyte SDK > Packages > flyte.remote > Action**
- **Flyte SDK > Packages > flyte.remote > ActionDetails**
- **Flyte SDK > Packages > flyte.remote > ActionInputs**
- **Flyte SDK > Packages > flyte.remote > ActionOutputs**
- **Flyte SDK > Packages > flyte.remote > App**
- **Flyte SDK > Packages > flyte.remote > Project**
- **Flyte SDK > Packages > flyte.remote > Run**
- **Flyte SDK > Packages > flyte.remote > RunDetails**
- **Flyte SDK > Packages > flyte.remote > Secret**
- **Flyte SDK > Packages > flyte.remote > Task**
- **Flyte SDK > Packages > flyte.remote > TaskDetails**
- **Flyte SDK > Packages > flyte.remote > TimeFilter**
- **Flyte SDK > Packages > flyte.remote > Trigger**
- **Flyte SDK > Packages > flyte.remote > User**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/action ===
# Action
**Package:** `flyte.remote`
A class representing an action. It is used to manage the "execution" of a task and its state on the remote API.
From a datamodel perspective, a Run consists of actions. All actions are linearly nested under a parent action.
Actions have auto-generated identifiers that are unique within a parent action.
<pre>
run
- a0
- action1 under a0
- action2 under a0
- action1 under action2 under a0
- action2 under action1 under action2 under a0
- ...
- ...
</pre>
## Parameters
```python
class Action(
pb2: run_definition_pb2.Action,
_details: ActionDetails | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `run_definition_pb2.Action` | |
| `_details` | `ActionDetails \| None` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `action_id` | `None` | Get the action ID. |
| `name` | `None` | Get the name of the action. |
| `phase` | `None` | Get the phase of the action. |
| `raw_phase` | `None` | Get the raw phase of the action. |
| `run_name` | `None` | Get the name of the run. |
| `start_time` | `None` | Get the start time of the action. |
| `task_name` | `None` | Get the name of the task. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Action > Methods > abort()** | Aborts / Terminates the action. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > details()** | Get the details of the action. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > done()** | Check if the action is done. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > get()** | Get an action by its URI or name. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > get_logs()** | Get logs for the action as an iterator of strings. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > listall()** | Get all actions for a given run. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > show_logs()** | Display logs for the action. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > sync()** | Sync the action with the remote server. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > to_json()** | Convert the object to a JSON string. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > wait()** | Wait for the run to complete, displaying a rich progress panel with status transitions. |
| **Flyte SDK > Packages > flyte.remote > Action > Methods > watch()** | Watch the action for updates, updating the internal Action state with latest details. |
### abort()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await action.abort.aio()`.
```python
def abort(
reason: str,
)
```
Aborts / Terminates the action.
| Parameter | Type | Description |
|-|-|-|
| `reason` | `str` | |
### details()
```python
def details()
```
Get the details of the action.
### done()
```python
def done()
```
Check if the action is done.
### get()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await Action.get.aio()`.
```python
def get(
cls,
uri: str | None,
run_name: str | None,
name: str | None,
) -> Action
```
Get an action by its URI or name. If both are provided, the URI takes precedence.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `uri` | `str \| None` | The URI of the action. |
| `run_name` | `str \| None` | The name of the run. |
| `name` | `str \| None` | The name of the action. |
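A minimal sketch of typical usage (the run and action names here are hypothetical, and an initialized connection to a Flyte backend is assumed):

```python
from flyte.remote import Action

# Hypothetical names; assumes a configured connection to a Flyte backend.
action = Action.get(run_name="my-run", name="a0")
print(action.phase, action.task_name)
```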
### get_logs()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await action.get_logs.aio()`.
```python
def get_logs(
attempt: int | None,
filter_system: bool,
show_ts: bool,
) -> AsyncGenerator[str, None]
```
Get logs for the action as an iterator of strings.
Can be called synchronously (returns `Iterator[str]`) or asynchronously
via `.aio()` (returns `AsyncIterator[str]`).
| Parameter | Type | Description |
|-|-|-|
| `attempt` | `int \| None` | The attempt number to retrieve logs for (defaults to latest attempt). |
| `filter_system` | `bool` | If True, filter out system-generated log lines. |
| `show_ts` | `bool` | If True, prefix each line with an ISO-8601 timestamp. |
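A sketch of streaming logs synchronously (hypothetical run and action names; assumes a configured backend connection):

```python
from flyte.remote import Action

# Hypothetical names; assumes a configured connection to a Flyte backend.
action = Action.get(run_name="my-run", name="a0")

# Stream logs for the latest attempt, hiding system-generated lines and
# prefixing each line with an ISO-8601 timestamp.
for line in action.get_logs(attempt=None, filter_system=True, show_ts=True):
    print(line)
```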
### listall()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await Action.listall.aio()`.
```python
def listall(
cls,
for_run_name: str,
in_phase: Tuple[ActionPhase | str, ...] | None,
sort_by: Tuple[str, Literal['asc', 'desc']] | None,
created_at: TimeFilter | None,
updated_at: TimeFilter | None,
) -> Union[Iterator[Action], AsyncIterator[Action]]
```
Get all actions for a given run.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `for_run_name` | `str` | The name of the run. |
| `in_phase` | `Tuple[ActionPhase \| str, ...] \| None` | Filter actions by one or more phases. |
| `sort_by` | `Tuple[str, Literal['asc', 'desc']] \| None` | The sorting criteria for the action list, in the format (field, order). |
| `created_at` | `TimeFilter \| None` | Filter actions by creation time range. |
| `updated_at` | `TimeFilter \| None` | Filter actions by last-update time range. |
**Returns:** An iterator of actions.
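For example, to list only the actions of a run that are currently running (hypothetical run name; phase strings are an assumption based on the `ActionPhase` values referenced elsewhere on this page):

```python
from flyte.remote import Action

# Hypothetical run name; assumes a configured connection to a Flyte backend.
for action in Action.listall(for_run_name="my-run", in_phase=("RUNNING",)):
    print(action.name, action.phase)
```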
### show_logs()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await action.show_logs.aio()`.
```python
def show_logs(
attempt: int | None,
max_lines: int,
show_ts: bool,
raw: bool,
filter_system: bool,
)
```
Display logs for the action.
| Parameter | Type | Description |
|-|-|-|
| `attempt` | `int \| None` | The attempt number to show logs for (defaults to latest attempt). |
| `max_lines` | `int` | Maximum number of log lines to display in the viewer. |
| `show_ts` | `bool` | Whether to show timestamps with each log line. |
| `raw` | `bool` | If True, print logs directly without the interactive viewer. |
| `filter_system` | `bool` | If True, filter out system-generated log lines. |
### sync()
```python
def sync()
```
Sync the action with the remote server.
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### wait()
```python
def wait(
quiet: bool,
wait_for: WaitFor,
)
```
Wait for the run to complete, displaying a rich progress panel with status transitions,
time elapsed, and error details in case of failure.
| Parameter | Type | Description |
|-|-|-|
| `quiet` | `bool` | |
| `wait_for` | `WaitFor` | |
### watch()
```python
def watch(
cache_data_on_done: bool,
wait_for: WaitFor,
) -> AsyncGenerator[ActionDetails, None]
```
Watch the action for updates, updating the internal Action state with latest details.
This method updates both the cached details and the protobuf representation,
ensuring that properties like `phase` reflect the current state.
| Parameter | Type | Description |
|-|-|-|
| `cache_data_on_done` | `bool` | |
| `wait_for` | `WaitFor` | |
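Putting it together, a sketch of blocking until completion (hypothetical names; assumes a configured backend connection and that `wait()`'s parameters have usable defaults):

```python
from flyte.remote import Action

# Hypothetical names; assumes a configured connection to a Flyte backend.
action = Action.get(run_name="my-run", name="a0")
action.wait()  # blocks, showing a progress panel until the action completes
print(action.phase)
```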
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/actiondetails ===
# ActionDetails
**Package:** `flyte.remote`
A class representing the details of an action. It is used to manage the run of a task and its state on the remote Union API.
## Parameters
```python
class ActionDetails(
pb2: run_definition_pb2.ActionDetails,
_inputs: ActionInputs | None,
_outputs: ActionOutputs | None,
_preserve_original_types: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `run_definition_pb2.ActionDetails` | |
| `_inputs` | `ActionInputs \| None` | |
| `_outputs` | `ActionOutputs \| None` | |
| `_preserve_original_types` | `bool` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `abort_info` | `None` | Get the abort information if the action was aborted, otherwise returns None. |
| `action_id` | `None` | Get the action ID. |
| `attempts` | `None` | Get the number of attempts of the action. |
| `error_info` | `None` | Get the error information if the action failed, otherwise returns None. |
| `initializing_time` | `None` | Get the time spent in the INITIALIZING phase for the latest attempt. |
| `is_running` | `None` | Check if the action is currently running. |
| `metadata` | `None` | Get the metadata of the action. |
| `name` | `None` | Get the name of the action. |
| `phase` | `None` | Get the phase of the action. |
| `phase_durations` | `None` | Get the duration spent in each phase as a dictionary, mapping ActionPhase to timedelta for the latest attempt. Provides an easy way to see how long the action spent queued, initializing, running, etc. |
| `queued_time` | `None` | Get the time spent in the QUEUED phase for the latest attempt. |
| `raw_phase` | `None` | Get the raw phase of the action. |
| `run_name` | `None` | Get the name of the run. |
| `running_time` | `None` | Get the time spent in the RUNNING phase for the latest attempt. |
| `runtime` | `None` | Get the runtime of the action. |
| `status` | `None` | Get the status of the action. |
| `task_name` | `None` | Get the name of the task. |
| `waiting_for_resources_time` | `None` | Get the time spent in the WAITING_FOR_RESOURCES phase for the latest attempt. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > done()** | Check if the action is in a terminal state (completed or failed). |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > get()** | Get an action's details by its URI or name. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > get_details()** | Get the details of the action. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > get_phase_transitions()** | Get the phase transitions for a specific attempt, showing the granular breakdown. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > inputs()** | Return the inputs of the action. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > logs_available()** | Check if logs are available for the action, optionally for a specific attempt. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > outputs()** | Return the outputs of the action, fetching them if they are not already cached. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > to_json()** | Convert the object to a JSON string. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > watch()** | Watch the action for updates. |
| **Flyte SDK > Packages > flyte.remote > ActionDetails > Methods > watch_updates()** | Watch for updates to the action details, yielding each update until the action is done. |
### done()
```python
def done()
```
Check if the action is in a terminal state (completed or failed).
### get()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await ActionDetails.get.aio()`.
```python
def get(
cls,
uri: str | None,
run_name: str | None,
name: str | None,
) -> ActionDetails
```
Get an action's details by its URI or name. If both are provided, the URI takes precedence.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `uri` | `str \| None` | The URI of the action. |
| `run_name` | `str \| None` | The name of the run. |
| `name` | `str \| None` | The name of the action. |
### get_details()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await ActionDetails.get_details.aio()`.
```python
def get_details(
cls,
action_id: identifier_pb2.ActionIdentifier,
) -> ActionDetails
```
Get the details of the action.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `action_id` | `identifier_pb2.ActionIdentifier` | |
### get_phase_transitions()
```python
def get_phase_transitions(
attempt: int | None,
) -> List[PhaseTransitionInfo]
```
Get the phase transitions for a specific attempt, showing the granular breakdown
of time spent in each phase (queued, initializing, running, etc.).
Example:
>>> action = Action.get(run_name="my-run", name="my-action")
>>> details = action.details()
>>> transitions = details.get_phase_transitions()
>>> for t in transitions:
... print(f"{t.phase}: {t.duration.total_seconds()}s")
| Parameter | Type | Description |
|-|-|-|
| `attempt` | `int \| None` | The attempt number (1-indexed). If None, uses the latest attempt. |
**Returns**
List of PhaseTransitionInfo objects, one for each phase the action went through.
### inputs()
```python
def inputs()
```
Return the inputs of the action.
Returns instantly if the inputs are already available; otherwise fetches them first.
### logs_available()
```python
def logs_available(
attempt: int | None,
) -> bool
```
Check if logs are available for the action, optionally for a specific attempt.
If attempt is None, it checks for the latest attempt.
| Parameter | Type | Description |
|-|-|-|
| `attempt` | `int \| None` | |
### outputs()
```python
def outputs()
```
Return the outputs of the action. Returns instantly if the outputs are already cached; otherwise fetches them first. Raises a RuntimeError if the action is not in a terminal state.
**Returns:** ActionOutputs
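A sketch of retrieving outputs once an action has finished (hypothetical names; assumes a configured backend connection):

```python
from flyte.remote import ActionDetails

# Hypothetical names; assumes a configured connection to a Flyte backend.
details = ActionDetails.get(run_name="my-run", name="a0")
if details.done():
    # outputs() raises RuntimeError before the action reaches a terminal state.
    print(details.outputs().named_outputs)
```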
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### watch()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await ActionDetails.watch.aio()`.
```python
def watch(
cls,
action_id: identifier_pb2.ActionIdentifier,
) -> AsyncIterator[ActionDetails]
```
Watch the action for updates.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `action_id` | `identifier_pb2.ActionIdentifier` | |
### watch_updates()
```python
def watch_updates(
cache_data_on_done: bool,
) -> AsyncGenerator[ActionDetails, None]
```
Watch for updates to the action details, yielding each update until the action is done.
| Parameter | Type | Description |
|-|-|-|
| `cache_data_on_done` | `bool` | If True, cache inputs and outputs when the action completes. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/actioninputs ===
# ActionInputs
**Package:** `flyte.remote`
A class representing the inputs of an action. It is used to manage the inputs of a task and its state on the
remote Union API.
ActionInputs extends `UserDict` and is therefore accessible like a dictionary.
Example Usage:
```python
action = Action.get(...)
print(action.inputs())
```
Output:
```python
{
"x": ...,
"y": ...,
}
```
## Parameters
```python
class ActionInputs(
pb2: common_pb2.Inputs,
data: Dict[str, Any],
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `common_pb2.Inputs` | |
| `data` | `Dict[str, Any]` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > clear()** | Remove all items from the dictionary. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > copy()** | Return a shallow copy. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > fromkeys()** | Create a new dictionary with keys from an iterable. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > get()** | Return the value for a key if present, else a default. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > items()** | Return a view of the dictionary's items. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > keys()** | Return a view of the dictionary's keys. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > pop()** | Remove a key and return its value. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > popitem()** | Remove and return a (key, value) pair. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > setdefault()** | Return the value for a key, setting it to a default if absent. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > to_json()** | Convert the object to a JSON string. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > update()** | Update the dictionary from a mapping or iterable. |
| **Flyte SDK > Packages > flyte.remote > ActionInputs > Methods > values()** | Return a view of the dictionary's values. |
### clear()
```python
def clear()
```
D.clear() -> None. Remove all items from D.
### copy()
```python
def copy()
```
### fromkeys()
```python
def fromkeys(
iterable,
value,
)
```
| Parameter | Type | Description |
|-|-|-|
| `iterable` | | |
| `value` | | |
### get()
```python
def get(
key,
default,
)
```
D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.
| Parameter | Type | Description |
|-|-|-|
| `key` | | |
| `default` | | |
### items()
```python
def items()
```
D.items() -> a set-like object providing a view on D's items
### keys()
```python
def keys()
```
D.keys() -> a set-like object providing a view on D's keys
### pop()
```python
def pop(
key,
default,
)
```
D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
| Parameter | Type | Description |
|-|-|-|
| `key` | | |
| `default` | | |
### popitem()
```python
def popitem()
```
D.popitem() -> (k, v), remove and return some (key, value) pair
as a 2-tuple; but raise KeyError if D is empty.
### setdefault()
```python
def setdefault(
key,
default,
)
```
D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D
| Parameter | Type | Description |
|-|-|-|
| `key` | | |
| `default` | | |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### update()
```python
def update(
other,
kwds,
)
```
D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
If E present and has a .keys() method, does: for k in E: D[k] = E[k]
If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v
In either case, this is followed by: for k, v in F.items(): D[k] = v
| Parameter | Type | Description |
|-|-|-|
| `other` | | |
| `kwds` | | |
### values()
```python
def values()
```
D.values() -> an object providing a view on D's values
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/actionoutputs ===
# ActionOutputs
**Package:** `flyte.remote`
A class representing the outputs of an action. Outputs are represented as a tuple by default: you can unpack them into individual variables or access them by index. Alternatively, use the `named_outputs` property to retrieve a dictionary keyed by output names, which are usually auto-generated (`o0, o1, o2, ...`).
Example Usage:
```python
action = Action.get(...)
print(action.outputs())
```
Output:
```python
("val1", "val2", ...)
```
OR
```python
action = Action.get(...)
print(action.outputs().named_outputs)
```
Output:
```python
{"o0": "val1", "o1": "val2", ...}
```
## Parameters
```python
class ActionOutputs(
pb2: common_pb2.Outputs,
data: Tuple[Any, ...],
fields: List[str] | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `common_pb2.Outputs` | |
| `data` | `Tuple[Any, ...]` | |
| `fields` | `List[str] \| None` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `named_outputs` | `None` | A dictionary of outputs keyed by output name. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > ActionOutputs > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > ActionOutputs > Methods > to_json()** | Convert the object to a JSON string. |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/app ===
# App
**Package:** `flyte.remote`
## Parameters
```python
class App(
pb2: app_definition_pb2.App,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `app_definition_pb2.App` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `deployment_status` | `None` | Get the deployment status of the app. |
| `desired_state` | `None` | Get the desired state of the app. |
| `endpoint` | `None` | Get the public endpoint URL of the app. |
| `name` | `None` | Get the name of the app. |
| `revision` | `None` | Get the revision number of the app. |
| `url` | `None` | Get the console URL for viewing the app. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > App > Methods > activate()** | Start the app. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > create()** | |
| **Flyte SDK > Packages > flyte.remote > App > Methods > deactivate()** | Stop the app. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > delete()** | Delete an app by name. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > ephemeral_ctx()** | Async context manager that activates the app and deactivates it when the context is exited. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > ephemeral_ctx_sync()** | Context manager that activates the app and deactivates it when the context is exited. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > get()** | Get an app by name. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > is_active()** | Check if the app is currently active or started. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > is_deactivated()** | Check if the app is currently deactivated or stopped. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > listall()** | |
| **Flyte SDK > Packages > flyte.remote > App > Methods > replace()** | Replace an existing app that matches the given name with a new spec and, optionally, labels. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > to_json()** | Convert the object to a JSON string. |
| **Flyte SDK > Packages > flyte.remote > App > Methods > update()** | |
| **Flyte SDK > Packages > flyte.remote > App > Methods > watch()** | Watch for the app to reach activated or deactivated state. |
### activate()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await app.activate.aio()`.
```python
def activate(
wait: bool,
) -> App
```
Start the app.
| Parameter | Type | Description |
|-|-|-|
| `wait` | `bool` | Wait for the app to reach activated state |
### create()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await App.create.aio()`.
```python
def create(
cls,
app: app_definition_pb2.App,
) -> App
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `app` | `app_definition_pb2.App` | |
### deactivate()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await app.deactivate.aio()`.
```python
def deactivate(
wait: bool,
) -> App
```
Stop the app.
| Parameter | Type | Description |
|-|-|-|
| `wait` | `bool` | Wait for the app to reach the deactivated state |
### delete()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await App.delete.aio()`.
```python
def delete(
cls,
name: str,
project: str | None,
domain: str | None,
)
```
Delete an app by name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | The name of the app to delete. |
| `project` | `str \| None` | The project containing the app. |
| `domain` | `str \| None` | The domain containing the app. |
### ephemeral_ctx()
```python
def ephemeral_ctx()
```
Async context manager that activates the app and deactivates it when the context is exited.
### ephemeral_ctx_sync()
```python
def ephemeral_ctx_sync()
```
Context manager that activates the app and deactivates it when the context is exited.
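For example, a sketch of running an app only for the duration of a block (hypothetical app name; assumes a configured backend connection):

```python
from flyte.remote import App

# Hypothetical app name; assumes a configured connection to a Flyte backend.
app = App.get(name="my-app")

# The app is activated on entering the context and deactivated on exit.
with app.ephemeral_ctx_sync():
    print(app.endpoint)
```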
### get()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await App.get.aio()`.
```python
def get(
cls,
name: str,
project: str | None,
domain: str | None,
) -> App
```
Get an app by name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | The name of the app. |
| `project` | `str \| None` | The project of the app. |
| `domain` | `str \| None` | The domain of the app. |
**Returns:** The app remote object.
### is_active()
```python
def is_active()
```
Check if the app is currently active or started.
### is_deactivated()
```python
def is_deactivated()
```
Check if the app is currently deactivated or stopped.
### listall()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await App.listall.aio()`.
```python
def listall(
cls,
created_by_subject: str | None,
sort_by: Tuple[str, Literal['asc', 'desc']] | None,
limit: int,
) -> AsyncIterator[App]
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `created_by_subject` | `str \| None` | |
| `sort_by` | `Tuple[str, Literal['asc', 'desc']] \| None` | |
| `limit` | `int` | |
### replace()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await App.replace.aio()`.
```python
def replace(
cls,
name: str,
updated_app_spec: app_definition_pb2.Spec,
reason: str,
labels: Mapping[str, str] | None,
project: str | None,
domain: str | None,
) -> App
```
Replace an existing app that matches the given name with a new spec and, optionally, labels.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | Name of the app to replace. |
| `updated_app_spec` | `app_definition_pb2.Spec` | Updated app spec |
| `reason` | `str` | |
| `labels` | `Mapping[str, str] \| None` | Optional labels for the new app |
| `project` | `str \| None` | Optional project for the new app |
| `domain` | `str \| None` | Optional domain for the new app |
**Returns:** The updated app.
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### update()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await App.update.aio()`.
```python
def update(
cls,
updated_app_proto: app_definition_pb2.App,
reason: str,
) -> App
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `updated_app_proto` | `app_definition_pb2.App` | |
| `reason` | `str` | |
### watch()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await app.watch.aio()`.
```python
def watch(
wait_for: WaitFor,
) -> App
```
Watch for the app to reach activated or deactivated state.
| Parameter | Type | Description |
|-|-|-|
| `wait_for` | `WaitFor` | The state to wait for: `"activated"` or `"deactivated"`. |
**Returns:** The app in the desired state.
**Raises:** RuntimeError if the app failed to reach the desired state.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/project ===
# Project
**Package:** `flyte.remote`
A class representing a project in the Union API.
## Parameters
```python
class Project(
pb2: project_service_pb2.Project,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `project_service_pb2.Project` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Project > Methods > archive()** | Archive this project. |
| **Flyte SDK > Packages > flyte.remote > Project > Methods > create()** | Create a new project. |
| **Flyte SDK > Packages > flyte.remote > Project > Methods > get()** | Get a project by name. |
| **Flyte SDK > Packages > flyte.remote > Project > Methods > listall()** | List all projects. |
| **Flyte SDK > Packages > flyte.remote > Project > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > Project > Methods > to_json()** | Convert the object to a JSON string. |
| **Flyte SDK > Packages > flyte.remote > Project > Methods > unarchive()** | Unarchive (activate) this project. |
| **Flyte SDK > Packages > flyte.remote > Project > Methods > update()** | Update an existing project. |
### archive()
```python
def archive()
```
Archive this project.
### create()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await Project.create.aio()`.
```python
def create(
cls,
id: str,
name: str,
description: str,
labels: Dict[str, str] | None,
) -> Project
```
Create a new project.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `id` | `str` | The unique identifier for the project. |
| `name` | `str` | The display name for the project. |
| `description` | `str` | A description for the project. |
| `labels` | `Dict[str, str] \| None` | Optional key-value labels for the project. |
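A sketch of creating a project (all values here are hypothetical; assumes a configured backend connection):

```python
from flyte.remote import Project

# Hypothetical values; assumes a configured connection to a Flyte backend.
project = Project.create(
    id="ml-experiments",
    name="ML Experiments",
    description="Scratch space for model experiments",
    labels={"team": "research"},
)
```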
### get()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await Project.get.aio()`.
```python
def get(
cls,
name: str,
) -> Project
```
Get a project by name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | The name of the project. |
### listall()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await Project.listall.aio()`.
```python
def listall(
cls,
filters: str | None,
sort_by: Tuple[str, Literal['asc', 'desc']] | None,
archived: bool,
) -> Union[AsyncIterator[Project], Iterator[Project]]
```
List all projects.
By default, lists active (unarchived) projects. Set `archived=True` to list
archived projects instead.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `filters` | `str \| None` | The filters to apply to the project list. |
| `sort_by` | `Tuple[str, Literal['asc', 'desc']] \| None` | The sorting criteria for the project list, in the format (field, order). |
| `archived` | `bool` | If True, list archived projects. If False (default), list active projects. |
**Returns:** An iterator of projects.
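A sketch of iterating over active projects, assuming a configured backend; the `name` sort field is an assumption:

```python
import flyte.remote as remote

# Sketch: iterate over active (unarchived) projects, sorted by name ascending.
for project in remote.Project.listall(
    filters=None,
    sort_by=("name", "asc"),  # the sort field name is an assumption
    archived=False,
):
    print(project.to_json())
```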
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### unarchive()
```python
def unarchive()
```
Unarchive (activate) this project.
### update()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Project.update.aio()`.
```python
def update(
cls,
id: str,
name: str | None,
description: str | None,
labels: Dict[str, str] | None,
state: Literal['archived', 'active'] | None,
) -> Project
```
Update an existing project.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `id` | `str` | The id of the project to update. |
| `name` | `str \| None` | New display name. If None, the existing name is preserved. |
| `description` | `str \| None` | New description. If None, the existing description is preserved. |
| `labels` | `Dict[str, str] \| None` | New labels. If None, the existing labels are preserved. |
| `state` | `Literal['archived', 'active'] \| None` | "archived" or "active". If None, the existing state is preserved. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/run ===
# Run
**Package:** `flyte.remote`
A class representing a run of a task. It is used to manage the run and its state on the remote
Union API.
## Parameters
```python
class Run(
pb2: run_definition_pb2.Run,
_details: RunDetails | None,
_preserve_original_types: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `run_definition_pb2.Run` | |
| `_details` | `RunDetails \| None` | |
| `_preserve_original_types` | `bool` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `name` | `None` | Get the name of the run. |
| `phase` | `None` | Get the phase of the run. |
| `raw_phase` | `None` | Get the raw phase of the run. |
| `url` | `None` | Get the URL of the run. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Run > Methods > abort()** | Aborts / Terminates the run. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > details()** | Get the details of the run. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > done()** | Check if the run is done. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > get()** | Get the current run. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > get_debug_url()** | Get the debug URL of the run. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > get_logs()** | Get logs for the run as an iterator of strings. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > inputs()** | Get the inputs of the run. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > listall()** | Get all runs for the current project and domain. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > outputs()** | Get the outputs of the run. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > show_logs()** | |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > sync()** | Sync the run with the remote server. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > to_json()** | Convert the object to a JSON string. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > wait()** | Wait for the run to complete, displaying a rich progress panel with status transitions. |
| **Flyte SDK > Packages > flyte.remote > Run > Methods > watch()** | Watch the run for updates, updating the internal Run state with latest details. |
### abort()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.abort.aio()`.
```python
def abort(
reason: str,
)
```
Aborts / Terminates the run.
| Parameter | Type | Description |
|-|-|-|
| `reason` | `str` | |
### details()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.details.aio()`.
```python
def details()
```
Get the details of the run. This is a placeholder for getting the run details.
### done()
```python
def done()
```
Check if the run is done.
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Run.get.aio()`.
```python
def get(
cls,
name: str,
) -> Run
```
Get the current run.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
**Returns:** The current run.
### get_debug_url()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.get_debug_url.aio()`.
```python
def get_debug_url()
```
Get the debug URL of the run. Returns `None` if the VS Code
Debugger log entry is not yet available in the action details.
### get_logs()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.get_logs.aio()`.
```python
def get_logs(
attempt: int | None,
filter_system: bool,
show_ts: bool,
) -> AsyncGenerator[str, None]
```
Get logs for the run as an iterator of strings.
Can be called synchronously (returns `Iterator[str]`) or asynchronously
via `.aio()` (returns `AsyncIterator[str]`).
| Parameter | Type | Description |
|-|-|-|
| `attempt` | `int \| None` | The attempt number to retrieve logs for (defaults to latest attempt). |
| `filter_system` | `bool` | If True, filter out system-generated log lines. |
| `show_ts` | `bool` | If True, prefix each line with an ISO-8601 timestamp. |
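As a sketch, tailing logs both ways (the run name is hypothetical, and a configured backend is assumed):

```python
import asyncio

import flyte.remote as remote

# Sketch: fetch an existing run by name (hypothetical name).
run = remote.Run.get(name="my-run")

# Synchronous call: yields plain strings.
for line in run.get_logs(attempt=None, filter_system=True, show_ts=True):
    print(line)

# Asynchronous variant via .aio(): yields an AsyncIterator[str].
async def tail() -> None:
    async for line in run.get_logs.aio(attempt=None, filter_system=True, show_ts=True):
        print(line)

asyncio.run(tail())
```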
### inputs()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.inputs.aio()`.
```python
def inputs()
```
Get the inputs of the run. This is a placeholder for getting the run inputs.
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Run.listall.aio()`.
```python
def listall(
cls,
in_phase: Tuple[ActionPhase | str, ...] | None,
task_name: str | None,
task_version: str | None,
created_by_subject: str | None,
sort_by: Tuple[str, Literal['asc', 'desc']] | None,
limit: int,
project: str | None,
domain: str | None,
created_at: TimeFilter | None,
updated_at: TimeFilter | None,
) -> AsyncIterator[Run]
```
Get all runs for the current project and domain.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `in_phase` | `Tuple[ActionPhase \| str, ...] \| None` | Filter runs by one or more phases. |
| `task_name` | `str \| None` | Filter runs by task name. |
| `task_version` | `str \| None` | Filter runs by task version. |
| `created_by_subject` | `str \| None` | Filter runs by the subject that created them (the auth subject identifier, not the username). |
| `sort_by` | `Tuple[str, Literal['asc', 'desc']] \| None` | The sorting criteria for the Run list, in the format (field, order). |
| `limit` | `int` | The maximum number of runs to return. |
| `project` | `str \| None` | The project to list runs for. Defaults to the globally configured project. |
| `domain` | `str \| None` | The domain to list runs for. Defaults to the globally configured domain. |
| `created_at` | `TimeFilter \| None` | Filter runs by creation time range. |
| `updated_at` | `TimeFilter \| None` | Filter runs by last-update time range. |
**Returns:** An iterator of runs.
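A sketch combining `listall` with a `TimeFilter`, assuming a configured backend and that the remaining parameters have defaults; the phase string `"FAILED"` is an assumption (see `ActionPhase` for the real values):

```python
from datetime import datetime, timedelta, timezone

import flyte.remote as remote

# Sketch: runs that entered a failed phase within the last 24 hours.
last_day = remote.TimeFilter(
    after=datetime.now(timezone.utc) - timedelta(hours=24),
    before=None,
)
for run in remote.Run.listall(in_phase=("FAILED",), created_at=last_day, limit=20):
    print(run.name, run.phase)
```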
### outputs()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.outputs.aio()`.
```python
def outputs()
```
Get the outputs of the run. This is a placeholder for getting the run outputs.
### show_logs()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.show_logs.aio()`.
```python
def show_logs(
attempt: int | None,
max_lines: int,
show_ts: bool,
raw: bool,
filter_system: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `attempt` | `int \| None` | |
| `max_lines` | `int` | |
| `show_ts` | `bool` | |
| `raw` | `bool` | |
| `filter_system` | `bool` | |
### sync()
```python
def sync()
```
Sync the run with the remote server. This is a placeholder for syncing the run.
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### wait()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await run.wait.aio()`.
```python
def wait(
quiet: bool,
wait_for: Literal['terminal', 'running'],
)
```
Wait for the run to complete, displaying a rich progress panel with status transitions,
time elapsed, and error details in case of failure.
This method updates the Run's internal state, ensuring that properties like
`run.action.phase` reflect the final state after waiting completes.
| Parameter | Type | Description |
|-|-|-|
| `quiet` | `bool` | |
| `wait_for` | `Literal['terminal', 'running']` | |
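A sketch of blocking until an existing run reaches a terminal phase, assuming a configured backend and a hypothetical run name:

```python
import flyte.remote as remote

# Sketch: fetch a run and wait for it to finish.
run = remote.Run.get(name="my-run")  # hypothetical run name
run.wait(quiet=False, wait_for="terminal")

# After wait() returns, the run's internal state reflects the final phase.
print(run.phase, run.url)
```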
### watch()
```python
def watch(
cache_data_on_done: bool,
) -> AsyncGenerator[ActionDetails, None]
```
Watch the run for updates, updating the internal Run state with latest details.
This method updates the Run's action state, ensuring that properties like
`run.action.phase` reflect the current state after watching.
| Parameter | Type | Description |
|-|-|-|
| `cache_data_on_done` | `bool` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/rundetails ===
# RunDetails
**Package:** `flyte.remote`
A class representing the details of a run of a task. It is used to manage the run and its state on the
remote Union API.
## Parameters
```python
class RunDetails(
pb2: run_definition_pb2.RunDetails,
_preserve_original_types: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `run_definition_pb2.RunDetails` | |
| `_preserve_original_types` | `bool` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `action_id` | `None` | Get the action ID. |
| `name` | `None` | Get the name of the action. |
| `task_name` | `None` | Get the name of the task. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > RunDetails > Methods > done()** | Check if the run is in a terminal state (completed or failed). |
| **Flyte SDK > Packages > flyte.remote > RunDetails > Methods > get()** | Get a run by its ID or name. |
| **Flyte SDK > Packages > flyte.remote > RunDetails > Methods > get_details()** | Get the details of the run. |
| **Flyte SDK > Packages > flyte.remote > RunDetails > Methods > inputs()** | Placeholder for inputs. |
| **Flyte SDK > Packages > flyte.remote > RunDetails > Methods > outputs()** | Placeholder for outputs. |
| **Flyte SDK > Packages > flyte.remote > RunDetails > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > RunDetails > Methods > to_json()** | Convert the object to a JSON string. |
### done()
```python
def done()
```
Check if the run is in a terminal state (completed or failed). This is a placeholder for checking the
run state.
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await RunDetails.get.aio()`.
```python
def get(
cls,
name: str | None,
) -> RunDetails
```
Get a run by its name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str \| None` | The name of the run. |
### get_details()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await RunDetails.get_details.aio()`.
```python
def get_details(
cls,
run_id: identifier_pb2.RunIdentifier,
) -> RunDetails
```
Get the details of the run. This is a placeholder for getting the run details.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `run_id` | `identifier_pb2.RunIdentifier` | |
### inputs()
```python
def inputs()
```
Placeholder for inputs. This can be extended to handle inputs from the run context.
### outputs()
```python
def outputs()
```
Placeholder for outputs. This can be extended to handle outputs from the run context.
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/secret ===
# Secret
**Package:** `flyte.remote`
## Parameters
```python
class Secret(
pb2: definition_pb2.Secret,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `definition_pb2.Secret` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `name` | `None` | Get the name of the secret. |
| `type` | `None` | Get the type of the secret as a string ("regular" or "image_pull"). |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Secret > Methods > create()** | Create a new secret. |
| **Flyte SDK > Packages > flyte.remote > Secret > Methods > delete()** | Delete a secret by name. |
| **Flyte SDK > Packages > flyte.remote > Secret > Methods > get()** | Retrieve a secret by name. |
| **Flyte SDK > Packages > flyte.remote > Secret > Methods > listall()** | List all secrets in the current project and domain. |
| **Flyte SDK > Packages > flyte.remote > Secret > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > Secret > Methods > to_json()** | Convert the object to a JSON string. |
### create()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Secret.create.aio()`.
```python
def create(
cls,
name: str,
value: Union[str, bytes],
type: SecretTypes,
)
```
Create a new secret.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | The name of the secret. |
| `value` | `Union[str, bytes]` | The secret value as a string or bytes. |
| `type` | `SecretTypes` | Type of secret - either "regular" or "image_pull". |
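A sketch of storing and reading back a secret, assuming a configured backend; passing the type as the literal string `"regular"` is an assumption about how `SecretTypes` is accepted:

```python
import flyte.remote as remote

# Sketch: store an API token as a regular secret, then read it back.
remote.Secret.create(name="my-api-token", value="s3cr3t", type="regular")
secret = remote.Secret.get(name="my-api-token")
print(secret.name, secret.type)
```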
### delete()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Secret.delete.aio()`.
```python
def delete(
cls,
name,
)
```
Delete a secret by name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | | The name of the secret to delete. |
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Secret.get.aio()`.
```python
def get(
cls,
name: str,
) -> Secret
```
Retrieve a secret by name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | The name of the secret to retrieve. |
**Returns:** A Secret object.
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Secret.listall.aio()`.
```python
def listall(
cls,
limit: int,
) -> AsyncIterator[Secret]
```
List all secrets in the current project and domain.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `limit` | `int` | Maximum number of secrets to return per page. |
**Returns:** An async iterator of Secret objects.
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/task ===
# Task
**Package:** `flyte.remote`
## Parameters
```python
class Task(
pb2: task_definition_pb2.Task,
)
```
Initialize a Task object.
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `task_definition_pb2.Task` | The task protobuf definition. |
## Properties
| Property | Type | Description |
|-|-|-|
| `name` | `None` | The name of the task. |
| `url` | `None` | Get the console URL for viewing the task. |
| `version` | `None` | The version of the task. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Task > Methods > get()** | Get a task by its ID or name. |
| **Flyte SDK > Packages > flyte.remote > Task > Methods > listall()** | Get all tasks for the current project and domain. |
| **Flyte SDK > Packages > flyte.remote > Task > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > Task > Methods > to_json()** | Convert the object to a JSON string. |
### get()
```python
def get(
name: str,
project: str | None,
domain: str | None,
version: str | None,
auto_version: AutoVersioning | None,
) -> LazyEntity
```
Get a task by name. Either `version` or `auto_version` is required.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the task. |
| `project` | `str \| None` | The project of the task. |
| `domain` | `str \| None` | The domain of the task. |
| `version` | `str \| None` | The version of the task. |
| `auto_version` | `AutoVersioning \| None` | If set to "latest", the most recently created version of the task will be used. If set to "current", the version is derived from the calling task's context; this is useful if you are deploying all environments with the same version. When auto_version is "current", the task can only be accessed from within a task context. |
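A sketch of fetching the most recent version of a deployed task, assuming a configured backend; the task name is hypothetical, and `None` falls back to the configured project and domain:

```python
import flyte.remote as remote

# Sketch: fetch the latest version of a deployed task as a LazyEntity.
task = remote.Task.get(
    name="my_env.train_model",  # hypothetical task name
    project=None,               # fall back to the configured project
    domain=None,                # fall back to the configured domain
    version=None,
    auto_version="latest",
)
```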
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Task.listall.aio()`.
```python
def listall(
cls,
by_task_name: str | None,
by_task_env: str | None,
project: str | None,
domain: str | None,
sort_by: Tuple[str, Literal['asc', 'desc']] | None,
limit: int,
) -> Union[AsyncIterator[Task], Iterator[Task]]
```
Get all tasks for the current project and domain.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `by_task_name` | `str \| None` | If provided, only tasks with this name will be returned. |
| `by_task_env` | `str \| None` | If provided, only tasks with this environment prefix will be returned. |
| `project` | `str \| None` | The project to filter tasks by. If None, the current project will be used. |
| `domain` | `str \| None` | The domain to filter tasks by. If None, the current domain will be used. |
| `sort_by` | `Tuple[str, Literal['asc', 'desc']] \| None` | The sorting criteria for the task list, in the format (field, order). |
| `limit` | `int` | The maximum number of tasks to return. |
**Returns:** An iterator of tasks.
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/taskdetails ===
# TaskDetails
**Package:** `flyte.remote`
## Parameters
```python
class TaskDetails(
pb2: task_definition_pb2.TaskDetails,
max_inline_io_bytes: int,
overriden_queue: Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `task_definition_pb2.TaskDetails` | |
| `max_inline_io_bytes` | `int` | |
| `overriden_queue` | `Optional[str]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `cache` | `None` | The cache policy of the task. |
| `default_input_args` | `None` | The default input arguments of the task. |
| `interface` | `None` | The interface of the task. |
| `name` | `None` | The name of the task. |
| `queue` | `None` | Get the queue name to use for task execution, if overridden. |
| `required_args` | `None` | The required input arguments of the task. |
| `resources` | `None` | Get the resource requests and limits for the task as a tuple (requests, limits). |
| `secrets` | `None` | Get the list of secret keys required by the task. |
| `task_type` | `None` | The type of the task. |
| `version` | `None` | The version of the task. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > TaskDetails > Methods > fetch()** | |
| **Flyte SDK > Packages > flyte.remote > TaskDetails > Methods > get()** | Get a task by its ID or name. |
| **Flyte SDK > Packages > flyte.remote > TaskDetails > Methods > override()** | Create a new TaskDetails with overridden properties. |
| **Flyte SDK > Packages > flyte.remote > TaskDetails > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > TaskDetails > Methods > to_json()** | Convert the object to a JSON string. |
### fetch()
```python
def fetch(
name: str,
project: str | None,
domain: str | None,
version: str | None,
auto_version: AutoVersioning | None,
) -> TaskDetails
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `project` | `str \| None` | |
| `domain` | `str \| None` | |
| `version` | `str \| None` | |
| `auto_version` | `AutoVersioning \| None` | |
### get()
```python
def get(
name: str,
project: str | None,
domain: str | None,
version: str | None,
auto_version: AutoVersioning | None,
) -> LazyEntity
```
Get a task by name. Either `version` or `auto_version` is required.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the task. |
| `project` | `str \| None` | The project of the task. |
| `domain` | `str \| None` | The domain of the task. |
| `version` | `str \| None` | The version of the task. |
| `auto_version` | `AutoVersioning \| None` | If set to "latest", the most recently created version of the task will be used. If set to "current", the version is derived from the calling task's context; this is useful if you are deploying all environments with the same version. When auto_version is "current", the task can only be accessed from within a task context. |
### override()
```python
def override(
short_name: Optional[str],
resources: Optional[flyte.Resources],
retries: Union[int, flyte.RetryStrategy],
timeout: Optional[flyte.TimeoutType],
env_vars: Optional[Dict[str, str]],
secrets: Optional[flyte.SecretRequest],
max_inline_io_bytes: Optional[int],
cache: Optional[flyte.Cache],
queue: Optional[str],
kwargs: **kwargs,
) -> TaskDetails
```
Create a new TaskDetails with overridden properties.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional short name for the task. |
| `resources` | `Optional[flyte.Resources]` | Optional resource requirements. |
| `retries` | `Union[int, flyte.RetryStrategy]` | Number of retries or retry strategy. |
| `timeout` | `Optional[flyte.TimeoutType]` | Execution timeout. |
| `env_vars` | `Optional[Dict[str, str]]` | Environment variables to set. |
| `secrets` | `Optional[flyte.SecretRequest]` | Secret requests for the task. |
| `max_inline_io_bytes` | `Optional[int]` | Maximum inline I/O size in bytes. |
| `cache` | `Optional[flyte.Cache]` | Cache configuration. |
| `queue` | `Optional[str]` | Queue name for task execution. |
| `kwargs` | `**kwargs` | |
**Returns:** A new TaskDetails instance with the overrides applied.
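A sketch of fetching task details and overriding resources and retries, assuming a configured backend; the task name is hypothetical, and keyword-only usage plus the `flyte.Resources` fields shown are assumptions:

```python
import flyte
import flyte.remote as remote

# Sketch: fetch details for a pinned task version.
details = remote.TaskDetails.fetch(
    name="my_env.train_model",  # hypothetical task name
    project=None,
    domain=None,
    version="v1",               # hypothetical version
    auto_version=None,
)

# Produce a new TaskDetails with larger resources and more retries;
# the original object is left unchanged.
tuned = details.override(
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    retries=3,
)
```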
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/timefilter ===
# TimeFilter
**Package:** `flyte.remote`
Filter for time-based fields (e.g. created_at, updated_at).
## Parameters
```python
class TimeFilter(
after: datetime.datetime | None,
before: datetime.datetime | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `after` | `datetime.datetime \| None` | Return only entries at or after this datetime (inclusive). |
| `before` | `datetime.datetime \| None` | Return only entries before this datetime (exclusive). |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/trigger ===
# Trigger
**Package:** `flyte.remote`
Represents a trigger in the Flyte platform.
## Parameters
```python
class Trigger(
pb2: trigger_definition_pb2.Trigger,
details: TriggerDetails | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `trigger_definition_pb2.Trigger` | |
| `details` | `TriggerDetails \| None` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `automation_spec` | `None` | Get the automation specification for the trigger. |
| `id` | `None` | Get the unique identifier for the trigger. |
| `is_active` | `None` | Check if the trigger is currently active. |
| `name` | `None` | Get the name of the trigger. |
| `task_name` | `None` | Get the name of the task associated with this trigger. |
| `url` | `None` | Get the console URL for viewing the trigger. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > create()** | Create a new trigger in the Flyte platform. |
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > delete()** | Delete a trigger by its name. |
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > get()** | Retrieve a trigger by its name and associated task name. |
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > get_details()** | Get detailed information about this trigger. |
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > listall()** | List all triggers associated with a specific task or all tasks if no task name is provided. |
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > to_json()** | Convert the object to a JSON string. |
| **Flyte SDK > Packages > flyte.remote > Trigger > Methods > update()** | Activate or pause a trigger by its name and associated task name. |
### create()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Trigger.create.aio()`.
```python
def create(
cls,
trigger: flyte.Trigger,
task_name: str,
task_version: str | None,
) -> Trigger
```
Create a new trigger in the Flyte platform.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `trigger` | `flyte.Trigger` | The flyte.Trigger object containing the trigger definition. |
| `task_name` | `str` | The name of the task to associate with the trigger. |
| `task_version` | `str \| None` | |
### delete()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Trigger.delete.aio()`.
```python
def delete(
cls,
name: str,
task_name: str,
project: str | None,
domain: str | None,
)
```
Delete a trigger by its name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
| `task_name` | `str` | |
| `project` | `str \| None` | |
| `domain` | `str \| None` | |
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Trigger.get.aio()`.
```python
def get(
cls,
name: str,
task_name: str,
) -> TriggerDetails
```
Retrieve a trigger by its name and associated task name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
| `task_name` | `str` | |
### get_details()
```python
def get_details()
```
Get detailed information about this trigger.
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Trigger.listall.aio()`.
```python
def listall(
cls,
task_name: str | None,
task_version: str | None,
limit: int,
) -> AsyncIterator[Trigger]
```
List all triggers associated with a specific task or all tasks if no task name is provided.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `task_name` | `str \| None` | |
| `task_version` | `str \| None` | |
| `limit` | `int` | |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### update()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, append `.aio()` to the method name, e.g.,
> `result = await Trigger.update.aio()`.
```python
def update(
cls,
name: str,
task_name: str,
active: bool,
)
```
Update a trigger's active state by its name and associated task name; set `active=False` to pause it.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
| `task_name` | `str` | |
| `active` | `bool` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote/user ===
# User
**Package:** `flyte.remote`
Represents a user in the Flyte platform.
## Parameters
```python
class User(
pb2: UserInfoResponse,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `UserInfoResponse` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > User > Methods > get()** | Fetches information about the currently logged in user. |
| **Flyte SDK > Packages > flyte.remote > User > Methods > name()** | Get the name of the user. |
| **Flyte SDK > Packages > flyte.remote > User > Methods > subject()** | Get the subject identifier of the user. |
| **Flyte SDK > Packages > flyte.remote > User > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > User > Methods > to_json()** | Convert the object to a JSON string. |
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await User.get.aio()`.
```python
def get(
cls,
) -> User
```
Fetches information about the currently logged in user.
**Returns:** A `User` object containing details about the user.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
### name()
```python
def name()
```
Get the name of the user.
### subject()
```python
def subject()
```
Get the subject identifier of the user.
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.report ===
# flyte.report
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.report > Report** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.report > Methods > current_report()** | Get the current report. |
| **Flyte SDK > Packages > flyte.report > Methods > flush()** | Flush the report. |
| **Flyte SDK > Packages > flyte.report > Methods > get_tab()** | Get a tab by name. |
| **Flyte SDK > Packages > flyte.report > Methods > log()** | Log content to the main tab. |
| **Flyte SDK > Packages > flyte.report > Methods > replace()** | Replace the content of the main tab. |
## Methods
#### current_report()
```python
def current_report()
```
Get the current report. This is a dummy report if not in a task context.
**Returns:** The current report.
#### flush()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await flush.aio()`.
```python
def flush()
```
Flush the report.
#### get_tab()
```python
def get_tab(
name: str,
create_if_missing: bool,
) -> flyte.report._report.Tab
```
Get a tab by name. If the tab does not exist, create it.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the tab. |
| `create_if_missing` | `bool` | Whether to create the tab if it does not exist. |
**Returns:** The tab.
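The get-or-create behaviour of `get_tab()` can be sketched with a plain dict; `Tab` here is a hypothetical stand-in for `flyte.report._report.Tab`:

```python
class Tab:
    def __init__(self, name: str):
        self.name = name
        self.chunks: list[str] = []  # logged HTML fragments

_tabs: dict[str, Tab] = {}

def get_tab(name: str, create_if_missing: bool = True) -> Tab:
    if name not in _tabs:
        if not create_if_missing:
            raise KeyError(f"no tab named {name!r}")
        _tabs[name] = Tab(name)  # create on first access
    return _tabs[name]

t1 = get_tab("metrics")
t2 = get_tab("metrics")
print(t1 is t2)  # True: repeat lookups return the same tab
```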
#### log()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await log.aio()`.
```python
def log(
content: str,
do_flush: bool,
)
```
Log content to the main tab. The content should be a valid HTML string, but not a complete HTML document,
as it will be inserted into a div.
| Parameter | Type | Description |
|-|-|-|
| `content` | `str` | The content to log. |
| `do_flush` | `bool` | Whether to flush the report after logging. |
#### replace()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await replace.aio()`.
```python
def replace(
content: str,
do_flush: bool,
)
```
Replace the content of the main tab.
| Parameter | Type | Description |
|-|-|-|
| `content` | `str` | |
| `do_flush` | `bool` | |
**Returns:** The report.
## Subpages
- **Flyte SDK > Packages > flyte.report > Report**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.report/report ===
# Report
**Package:** `flyte.report`
## Parameters
```python
class Report(
name: str,
tabs: typing.Dict[str, flyte.report._report.Tab],
template_path: pathlib.Path,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `tabs` | `typing.Dict[str, flyte.report._report.Tab]` | |
| `template_path` | `pathlib.Path` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.report > Report > Methods > get_final_report()** | Get the final report as a string. |
| **Flyte SDK > Packages > flyte.report > Report > Methods > get_tab()** | Get a tab by name. |
### get_final_report()
```python
def get_final_report()
```
Get the final report as a string.
**Returns:** The final report.
### get_tab()
```python
def get_tab(
name: str,
create_if_missing: bool,
) -> flyte.report._report.Tab
```
Get a tab by name. If the tab does not exist, create it.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the tab. |
| `create_if_missing` | `bool` | Whether to create the tab if it does not exist. |
**Returns:** The tab.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.sandbox ===
# flyte.sandbox
Sandbox utilities for running isolated code inside Flyte tasks.
Warning: Experimental feature (alpha): APIs may change without notice.
`flyte.sandbox` provides two distinct sandboxing approaches:
---
**1. Orchestration sandbox**: powered by Monty
Runs pure Python *orchestration logic* (control flow, routing, aggregation)
with zero overhead. The Monty runtime enforces strong restrictions:
no imports, no IO, no network access, microsecond startup. Used via
`@env.sandbox.orchestrator` or `flyte.sandbox.orchestrator_from_str()`.
Sandboxed orchestrators are:
- **Side-effect free**: No filesystem, network, or OS access
- **Microsecond startup**: No container spin-up; runs in the same process
- **Multiplexable**: Many orchestrators run safely on the same Python process
Example:

```python
env = flyte.TaskEnvironment(name="my-env")

@env.sandbox.orchestrator
def route(x: int, y: int) -> int:
    return add(x, y)  # calls a worker task

pipeline = flyte.sandbox.orchestrator_from_str(
    "add(x, y) * 2",
    inputs={"x": int, "y": int},
    output=int,
    tasks=[add],
)
```
---
**2. Code sandbox**: arbitrary code in an isolated container
Runs arbitrary Python scripts or shell commands inside an ephemeral Docker
container. The image is built on demand from declared `packages` and
`system_packages`, executed once, then discarded. Network is blocked by
default (`block_network=True`), preventing outbound calls from untrusted
code. Used via `flyte.sandbox.create()`.
Three execution modes are supported:
- Code mode: provide Python source that runs with automatic input/output wiring.
- Verbatim mode: run a script that manages its own I/O via `/var/inputs` and `/var/outputs`.
- Command mode: execute an arbitrary command or entrypoint.
**Examples**

Code mode: provide Python code that uses inputs as variables and assigns outputs as Python values.

```python
_stats_code = """
import numpy as np
nums = np.array([float(v) for v in values.split(",")])
mean = float(np.mean(nums))
std = float(np.std(nums))
window_end = dt + delta
"""

stats_sandbox = flyte.sandbox.create(
    name="numpy-stats",
    code=_stats_code,
    inputs={
        "values": str,
        "dt": datetime.datetime,
        "delta": datetime.timedelta,
    },
    outputs={
        "mean": float,
        "std": float,
        "window_end": datetime.datetime,
    },
    packages=["numpy"],
)

mean, std, window_end = await stats_sandbox.run.aio(
    values="1,2,3,4,5",
    dt=datetime.datetime(2024, 1, 1),
    delta=datetime.timedelta(days=1),
)
```
Verbatim mode: run a script that explicitly reads inputs from `/var/inputs` and writes outputs to `/var/outputs`.

```python
_etl_script = """
import json, pathlib
payload = json.loads(
    pathlib.Path("/var/inputs/payload").read_text()
)
total = sum(payload["values"])
pathlib.Path("/var/outputs/total").write_text(str(total))
"""

etl_sandbox = flyte.sandbox.create(
    name="etl-script",
    code=_etl_script,
    inputs={"payload": File},
    outputs={"total": int},
    auto_io=False,
)
```
Command mode: execute an arbitrary command inside the sandbox environment.

```python
sandbox = flyte.sandbox.create(
    name="test-runner",
    command=["/bin/bash", "-c", "pytest /var/inputs/tests.py -q"],
    inputs={"tests.py": File},
    outputs={"exit_code": str},
)
```
**Notes**

- Inputs are materialized under `/var/inputs`.
- Outputs must be written to `/var/outputs`.
- In code mode, inputs are available as Python variables and scalar outputs are captured automatically.
- Additional Python dependencies can be specified via the `packages` argument.
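Conceptually, auto-IO mode injects the declared inputs as local variables, runs the code, and collects the declared outputs from the resulting namespace. A rough in-process sketch of that wiring (the real sandbox does this inside an ephemeral container with files under `/var/inputs` and `/var/outputs`; `run_auto_io` is a hypothetical helper):

```python
def run_auto_io(code: str, inputs: dict, output_names: list) -> dict:
    namespace = dict(inputs)  # declared inputs become local variables
    exec(code, {}, namespace)  # run the user's business logic
    # Capture the declared outputs from the namespace.
    return {name: namespace[name] for name in output_names}

result = run_auto_io(
    "mean = sum(nums) / len(nums)",
    inputs={"nums": [1.0, 2.0, 3.0]},
    output_names=["mean"],
)
print(result)  # {'mean': 2.0}
```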
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate** | A sandboxed task created from a code string rather than a decorated function. |
| **Flyte SDK > Packages > flyte.sandbox > ImageConfig** | Configuration for Docker image building at runtime. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedConfig** | Configuration for a sandboxed task executed via Monty. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate** | A task template that executes the function body in a Monty sandbox. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.sandbox > Methods > create()** | Create a stateless Python code sandbox. |
| **Flyte SDK > Packages > flyte.sandbox > Methods > orchestrate_local()** | One-shot local execution of a code string in the Monty sandbox. |
| **Flyte SDK > Packages > flyte.sandbox > Methods > orchestrator_from_str()** | Create a reusable sandboxed task from a code string. |
### Variables
| Property | Type | Description |
|-|-|-|
| `ORCHESTRATOR_SYNTAX_PROMPT` | `str` | |
| `sandbox_environment` | `TaskEnvironment` | |
## Methods
#### create()
```python
def create(
name: typing.Optional[str],
code: typing.Optional[str],
inputs: typing.Optional[dict[str, type]],
outputs: typing.Optional[dict[str, type]],
command: typing.Optional[list[str]],
arguments: typing.Optional[list[str]],
packages: typing.Optional[list[str]],
system_packages: typing.Optional[list[str]],
additional_commands: typing.Optional[list[str]],
resources: typing.Optional[flyte._resources.Resources],
image_config: typing.Optional[flyte.sandbox._code_sandbox.ImageConfig],
image_name: typing.Optional[str],
image: typing.Optional[str],
auto_io: bool,
retries: int,
timeout: typing.Optional[int],
env_vars: typing.Optional[dict[str, str]],
secrets: typing.Optional[list],
cache: str,
) -> flyte.sandbox._code_sandbox._Sandbox
```
Create a stateless Python code sandbox.
The sandbox is **stateless**: each invocation runs in a fresh, ephemeral container. No filesystem state, environment variables, or side effects carry over between runs.
Three modes, mutually exclusive:
- **Auto-IO mode** (`code` provided, `auto_io=True`, default): write
just the business logic. Flyte auto-generates an argparse preamble so
declared inputs are available as local variables, and writes declared
scalar outputs to `/var/outputs/` automatically. No boilerplate needed.
- **Verbatim mode** (`code` provided, `auto_io=False`): run an
arbitrary Python script as-is. CLI args for declared inputs are still
forwarded, but the script handles all I/O itself (reading from
`/var/inputs/`, writing to `/var/outputs/<name>` manually).
- **Command mode** (`command` provided): run any shell command directly,
e.g. a compiled binary or a shell pipeline.
Call `.run()` on the returned sandbox to build the image and execute.
Example: auto-IO mode (default, no boilerplate):

```python
sandbox = flyte.sandbox.create(
    name="double",
    code="result = x * 2",
    inputs={"x": int},
    outputs={"result": int},
)
result = await sandbox.run.aio(x=21)  # returns 42
```

Example: verbatim mode (complete Python script, full control):

```python
sandbox = flyte.sandbox.create(
    name="etl",
    code="""
import json, pathlib
data = json.loads(pathlib.Path("/var/inputs/payload").read_text())
pathlib.Path("/var/outputs/total").write_text(str(sum(data["values"])))
""",
    inputs={"payload": File},
    outputs={"total": int},
    auto_io=False,
)
```

Example: command mode:

```python
sandbox = flyte.sandbox.create(
    name="test-runner",
    command=["/bin/bash", "-c", pytest_cmd],
    arguments=["_", "/var/inputs/solution.py", "/var/inputs/tests.py"],
    inputs={"solution.py": File, "tests.py": File},
    outputs={"exit_code": str},
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `typing.Optional[str]` | Sandbox name. Derives task and image names. |
| `code` | `typing.Optional[str]` | Python source to run (auto-IO or verbatim mode). Mutually exclusive with `command`. |
| `inputs` | `typing.Optional[dict[str, type]]` | Input type declarations. Supported types: - Primitive: `int`, `float`, `str`, `bool` - Date/time: `datetime.datetime`, `datetime.timedelta` - IO handles: `flyte.io.File` (bind-mounted at `/var/inputs/<name>`; available as a path string in auto-IO mode) |
| `outputs` | `typing.Optional[dict[str, type]]` | Output type declarations. Supported types: - Primitive: `int`, `float`, `str`, `bool` - Date/time: `datetime.datetime` (ISO-8601), `datetime.timedelta` - IO handles: `flyte.io.File` (user code must write the file to `/var/outputs/<name>`) |
| `command` | `typing.Optional[list[str]]` | Entrypoint command (command mode). Mutually exclusive with `code`. |
| `arguments` | `typing.Optional[list[str]]` | Arguments forwarded to `command` (command mode only). |
| `packages` | `typing.Optional[list[str]]` | Python packages to install via pip. |
| `system_packages` | `typing.Optional[list[str]]` | System packages to install via apt. |
| `additional_commands` | `typing.Optional[list[str]]` | Extra Dockerfile `RUN` commands. |
| `resources` | `typing.Optional[flyte._resources.Resources]` | CPU / memory resources for the container. |
| `image_config` | `typing.Optional[flyte.sandbox._code_sandbox.ImageConfig]` | Registry and Python version settings. |
| `image_name` | `typing.Optional[str]` | Explicit image name, overrides the auto-generated one. |
| `image` | `typing.Optional[str]` | Pre-built image URI. Skips the build step if provided. |
| `auto_io` | `bool` | When `True` (default), Flyte wraps `code` with an auto-generated argparse preamble and output-writing epilogue so declared inputs are available as local variables and scalar outputs are collected automatically; no boilerplate needed. When `False`, `code` is run verbatim and must handle all I/O itself. |
| `retries` | `int` | Number of task retries on failure. |
| `timeout` | `typing.Optional[int]` | Task timeout in seconds. |
| `env_vars` | `typing.Optional[dict[str, str]]` | Environment variables available inside the container. |
| `secrets` | `typing.Optional[list]` | Flyte `flyte.Secret` objects to mount. |
| `cache` | `str` | Cache behaviour: `"auto"`, `"override"`, or `"disable"`. |
**Returns:** Configured sandbox ready to `.run()`.
#### orchestrate_local()
```python
def orchestrate_local(
source: str,
inputs: Dict[str, Any],
tasks: Optional[List[Any]],
timeout_ms: int,
) -> Any
```
One-shot local execution of a code string in the Monty sandbox.
Warning: Experimental feature (alpha): APIs may change without notice.
Sends the code + inputs to Monty and returns the result directly,
without creating a `TaskTemplate` or going through the controller.
The **last expression** in *source* becomes the return value:

```python
result = await sandbox.orchestrate_local(
    "add(x, y) * 2",
    inputs={"x": 1, "y": 2},
    tasks=[add],
)
# returns 6
```
| Parameter | Type | Description |
|-|-|-|
| `source` | `str` | Python code string to execute in the sandbox. |
| `inputs` | `Dict[str, Any]` | Mapping of input names to their values. |
| `tasks` | `Optional[List[Any]]` | List of external functions (tasks, durable ops) available inside the sandbox. Each item's `__name__` is used as the key. |
| `timeout_ms` | `int` | Sandbox execution timeout in milliseconds. |
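The last-expression-as-result convention can be sketched with Python's `ast` module. This `orchestrate` function is a plain-Python illustration only, not the Monty runtime (which adds the restrictions and async dispatch described above):

```python
import ast

def orchestrate(source: str, inputs: dict, tasks: list):
    # Build the execution environment: inputs plus tasks keyed by __name__.
    env = dict(inputs)
    env.update({fn.__name__: fn for fn in tasks})
    module = ast.parse(source)
    *body, last = module.body  # split off the final statement
    exec(compile(ast.Module(body=body, type_ignores=[]), "<sandbox>", "exec"), env)
    # The final expression's value becomes the result (assumes it is an expression).
    return eval(compile(ast.Expression(body=last.value), "<sandbox>", "eval"), env)

def add(x: int, y: int) -> int:
    return x + y

print(orchestrate("add(x, y) * 2", inputs={"x": 1, "y": 2}, tasks=[add]))  # 6
```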
#### orchestrator_from_str()
```python
def orchestrator_from_str(
source: str,
inputs: Dict[str, type],
output: type,
tasks: Optional[List[Any]],
name: str,
timeout_ms: int,
cache: CacheRequest,
retries: int,
image: Optional[Any],
) -> CodeTaskTemplate
```
Create a reusable sandboxed task from a code string.
Warning: Experimental feature (alpha): APIs may change without notice.
The returned `CodeTaskTemplate` can be passed to `flyte.run()`
just like a decorated task.
The **last expression** in *source* becomes the return value:

```python
pipeline = sandbox.orchestrator_from_str(
    "add(x, y) * 2",
    inputs={"x": int, "y": int},
    output=int,
    tasks=[add],
)
result = flyte.run(pipeline, x=1, y=2)  # returns 6
```
| Parameter | Type | Description |
|-|-|-|
| `source` | `str` | Python code string to execute in the sandbox. |
| `inputs` | `Dict[str, type]` | Mapping of input names to their types. |
| `output` | `type` | The return type (default `NoneType`). |
| `tasks` | `Optional[List[Any]]` | List of external functions (tasks, durable ops) available inside the sandbox. Each item's `__name__` is used as the key. |
| `name` | `str` | Task name (default `"sandboxed-code"`). |
| `timeout_ms` | `int` | Sandbox execution timeout in milliseconds. |
| `cache` | `CacheRequest` | Cache policy for the task. |
| `retries` | `int` | Number of retries on failure. |
| `image` | `Optional[Any]` | Docker image to use. If not provided, a default Debian image with `pydantic-monty` is created automatically. |
## Subpages
- **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate**
- **Flyte SDK > Packages > flyte.sandbox > ImageConfig**
- **Flyte SDK > Packages > flyte.sandbox > SandboxedConfig**
- **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.sandbox/codetasktemplate ===
# CodeTaskTemplate
**Package:** `flyte.sandbox`
A sandboxed task created from a code string rather than a decorated function.
Unlike `SandboxedTaskTemplate` (which extracts source from a Python
function), this class accepts pre-transformed source code and an explicit
dict of external functions. It is constructed via `flyte.sandbox.orchestrator_from_str`.
## Parameters
```python
class CodeTaskTemplate(
name: str,
interface: NativeInterface,
short_name: str,
task_type: str,
task_type_version: int,
image: Union[str, Image, Literal['auto']] | None,
resources: Optional[Resources],
cache: CacheRequest,
interruptible: bool,
retries: Union[int, RetryStrategy],
reusable: Union[ReusePolicy, None],
docs: Optional[Documentation],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
timeout: Optional[TimeoutType],
pod_template: Optional[Union[str, PodTemplate]],
report: bool,
queue: Optional[str],
debuggable: bool,
parent_env: Optional[weakref.ReferenceType[TaskEnvironment]],
parent_env_name: Optional[str],
max_inline_io_bytes: int,
triggers: Tuple[Trigger, ...],
links: Tuple[Link, ...],
_call_as_synchronous: bool,
func: F,
plugin_config: Optional[SandboxedConfig],
task_resolver: Optional[Any],
_user_source: str,
_user_input_names: List[str],
_user_functions: Dict[str, Any],
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `interface` | `NativeInterface` | |
| `short_name` | `str` | |
| `task_type` | `str` | |
| `task_type_version` | `int` | |
| `image` | `Union[str, Image, Literal['auto']] \| None` | |
| `resources` | `Optional[Resources]` | |
| `cache` | `CacheRequest` | |
| `interruptible` | `bool` | |
| `retries` | `Union[int, RetryStrategy]` | |
| `reusable` | `Union[ReusePolicy, None]` | |
| `docs` | `Optional[Documentation]` | |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `timeout` | `Optional[TimeoutType]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `report` | `bool` | |
| `queue` | `Optional[str]` | |
| `debuggable` | `bool` | |
| `parent_env` | `Optional[weakref.ReferenceType[TaskEnvironment]]` | |
| `parent_env_name` | `Optional[str]` | |
| `max_inline_io_bytes` | `int` | |
| `triggers` | `Tuple[Trigger, ...]` | |
| `links` | `Tuple[Link, ...]` | |
| `_call_as_synchronous` | `bool` | |
| `func` | `F` | |
| `plugin_config` | `Optional[SandboxedConfig]` | |
| `task_resolver` | `Optional[Any]` | |
| `_user_source` | `str` | |
| `_user_input_names` | `List[str]` | |
| `_user_functions` | `Dict[str, Any]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `json_schema` | `None` | JSON schema for the task inputs, following the Flyte standard. Delegates to NativeInterface.json_schema, which uses the type engine to produce a LiteralType per input and converts to JSON schema. |
| `native_interface` | `None` | |
| `source_file` | `None` | Returns the source file of the function, if available. This is useful for debugging and tracing. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > aio()** | The aio function allows executing "sync" tasks, in an async context. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > container_args()** | Returns the container args for the task. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > custom_config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > data_loading_config()** | This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > execute()** | Execute the function body in a Monty sandbox. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > forward()** | Not supported: there is no Python function to call directly. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > override()** | Override various parameters of the task template. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > post()** | The post-execute hook called after the task is executed. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > pre()** | The pre-execute hook called before the task is executed. |
| **Flyte SDK > Packages > flyte.sandbox > CodeTaskTemplate > Methods > sql()** | Returns the SQL for the task. |
### aio()
```python
def aio(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
The aio function allows executing "sync" tasks in an async context. This helps when migrating sync tasks defined in v1 for use within an asyncio parent task.
This function will also re-raise exceptions from the underlying task.
Example:
```python
@env.task
def my_legacy_task(x: int) -> int:
    return x

@env.task
async def my_new_parent_task(n: int) -> List[int]:
    collect = []
    for x in range(n):
        collect.append(my_legacy_task.aio(x))
    return await asyncio.gather(*collect)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### config()
```python
def config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### container_args()
```python
def container_args(
serialize_context: SerializationContext,
) -> List[str]
```
Returns the container arguments for the task. These configure the task execution environment at runtime and are usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
### custom_config()
```python
def custom_config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### data_loading_config()
```python
def data_loading_config(
sctx: SerializationContext,
) -> DataLoadingConfig
```
This configuration allows executing raw containers in Flyte using the Flyte CoPilot system.
Flyte CoPilot eliminates the need for the SDK inside the container: any inputs required by the user's container are side-loaded into the input path, and any outputs generated by the user's container within the output path are automatically uploaded.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### execute()
```python
def execute(
args,
kwargs,
) -> Any
```
Execute the function body in a Monty sandbox.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### forward()
```python
def forward(
args,
kwargs,
) -> Any
```
Not supported: there is no Python function to call directly.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### override()
```python
def override(
short_name: Optional[str],
resources: Optional[Resources],
cache: Optional[CacheRequest],
retries: Union[int, RetryStrategy],
timeout: Optional[TimeoutType],
reusable: Union[ReusePolicy, Literal['off'], None],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
max_inline_io_bytes: int | None,
pod_template: Optional[Union[str, PodTemplate]],
queue: Optional[str],
interruptible: Optional[bool],
links: Tuple[Link, ...],
kwargs: **kwargs,
) -> TaskTemplate
```
Override various parameters of the task template. This allows for dynamic configuration of the task
when it is called, such as changing the image, resources, cache policy, etc.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional override for the short name of the task. |
| `resources` | `Optional[Resources]` | Optional override for the resources to use for the task. |
| `cache` | `Optional[CacheRequest]` | Optional override for the cache policy for the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional override for the number of retries for the task. |
| `timeout` | `Optional[TimeoutType]` | Optional override for the timeout for the task. |
| `reusable` | `Union[ReusePolicy, Literal['off'], None]` | Optional override for the reusability policy for the task. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional override for the environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional override for the secrets that will be injected into the task at runtime. |
| `max_inline_io_bytes` | `int \| None` | Optional override for the maximum allowed size (in bytes) for all inputs and outputs passed directly to the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional override for the pod template to use for the task. |
| `queue` | `Optional[str]` | Optional override for the queue to use for the task. |
| `interruptible` | `Optional[bool]` | Optional override for the interruptible policy for the task. |
| `links` | `Tuple[Link, ...]` | Optional override for the Links associated with the task. |
| `kwargs` | `**kwargs` | Additional keyword arguments for further overrides. Some fields like name, image, docs, and interface cannot be overridden. |
**Returns:** A new TaskTemplate instance with the overridden parameters.
### post()
```python
def post(
return_vals: Any,
) -> Any
```
The post-execute hook, called after the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `return_vals` | `Any` | |
### pre()
```python
def pre(
args,
kwargs,
) -> Dict[str, Any]
```
The pre-execute hook, called before the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### sql()
```python
def sql(
sctx: SerializationContext,
) -> Optional[str]
```
Returns the SQL statement for the task, if any. This is usually used by SQL-based plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.sandbox/imageconfig ===
# ImageConfig
**Package:** `flyte.sandbox`
Configuration for Docker image building at runtime.
## Parameters
```python
class ImageConfig(
registry: typing.Optional[str],
registry_secret: typing.Optional[str],
python_version: typing.Optional[tuple[int, int]],
)
```
| Parameter | Type | Description |
|-|-|-|
| `registry` | `typing.Optional[str]` | |
| `registry_secret` | `typing.Optional[str]` | |
| `python_version` | `typing.Optional[tuple[int, int]]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.sandbox/sandboxedconfig ===
# SandboxedConfig
**Package:** `flyte.sandbox`
Configuration for a sandboxed task executed via Monty.
## Parameters
```python
class SandboxedConfig(
max_memory: int,
max_stack_depth: int,
timeout_ms: int,
type_check: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `max_memory` | `int` | |
| `max_stack_depth` | `int` | |
| `timeout_ms` | `int` | |
| `type_check` | `bool` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.sandbox/sandboxedtasktemplate ===
# SandboxedTaskTemplate
**Package:** `flyte.sandbox`
A task template that executes the function body in a Monty sandbox.
For pure Python functions (no external calls), Monty executes the
entire body without pausing. For functions that call other tasks or
durable operations, `run_monty_async` handles async dispatch.
## Parameters
```python
class SandboxedTaskTemplate(
name: str,
interface: NativeInterface,
short_name: str,
task_type: str,
task_type_version: int,
image: Union[str, Image, Literal['auto']] | None,
resources: Optional[Resources],
cache: CacheRequest,
interruptible: bool,
retries: Union[int, RetryStrategy],
reusable: Union[ReusePolicy, None],
docs: Optional[Documentation],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
timeout: Optional[TimeoutType],
pod_template: Optional[Union[str, PodTemplate]],
report: bool,
queue: Optional[str],
debuggable: bool,
parent_env: Optional[weakref.ReferenceType[TaskEnvironment]],
parent_env_name: Optional[str],
max_inline_io_bytes: int,
triggers: Tuple[Trigger, ...],
links: Tuple[Link, ...],
_call_as_synchronous: bool,
func: F,
plugin_config: Optional[SandboxedConfig],
task_resolver: Optional[Any],
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `interface` | `NativeInterface` | |
| `short_name` | `str` | |
| `task_type` | `str` | |
| `task_type_version` | `int` | |
| `image` | `Union[str, Image, Literal['auto']] \| None` | |
| `resources` | `Optional[Resources]` | |
| `cache` | `CacheRequest` | |
| `interruptible` | `bool` | |
| `retries` | `Union[int, RetryStrategy]` | |
| `reusable` | `Union[ReusePolicy, None]` | |
| `docs` | `Optional[Documentation]` | |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `timeout` | `Optional[TimeoutType]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `report` | `bool` | |
| `queue` | `Optional[str]` | |
| `debuggable` | `bool` | |
| `parent_env` | `Optional[weakref.ReferenceType[TaskEnvironment]]` | |
| `parent_env_name` | `Optional[str]` | |
| `max_inline_io_bytes` | `int` | |
| `triggers` | `Tuple[Trigger, ...]` | |
| `links` | `Tuple[Link, ...]` | |
| `_call_as_synchronous` | `bool` | |
| `func` | `F` | |
| `plugin_config` | `Optional[SandboxedConfig]` | |
| `task_resolver` | `Optional[Any]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `json_schema` | `None` | JSON schema for the task inputs, following the Flyte standard. Delegates to NativeInterface.json_schema, which uses the type engine to produce a LiteralType per input and converts to JSON schema. |
| `native_interface` | `None` | |
| `source_file` | `None` | Returns the source file of the function, if available. This is useful for debugging and tracing. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > aio()** | The aio function allows executing "sync" tasks, in an async context. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > container_args()** | Returns the container args for the task. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > custom_config()** | Returns additional configuration for the task. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > data_loading_config()** | This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > execute()** | Execute the function body in a Monty sandbox. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > forward()** | Bypass Monty and call the function directly (for local/debug execution). |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > override()** | Override various parameters of the task template. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > post()** | The post-execute hook that is called after the task is executed. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > pre()** | The pre-execute hook that is called before the task is executed. |
| **Flyte SDK > Packages > flyte.sandbox > SandboxedTaskTemplate > Methods > sql()** | Returns the SQL for the task. |
### aio()
```python
def aio(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
The `aio` function allows executing "sync" tasks in an async context. This helps when migrating sync tasks defined in v1 to be used within an asyncio parent task.
This function will also re-raise exceptions from the underlying task.
Example:
```python
import asyncio
from typing import List

@env.task
def my_legacy_task(x: int) -> int:
    return x

@env.task
async def my_new_parent_task(n: int) -> List[int]:
    collect = []
    for x in range(n):
        collect.append(my_legacy_task.aio(x))
    return await asyncio.gather(*collect)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### config()
```python
def config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### container_args()
```python
def container_args(
serialize_context: SerializationContext,
) -> List[str]
```
Returns the container arguments for the task, i.e., the command-line arguments used to launch the task container. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
### custom_config()
```python
def custom_config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### data_loading_config()
```python
def data_loading_config(
sctx: SerializationContext,
) -> DataLoadingConfig
```
This configuration allows executing raw containers in Flyte using the Flyte CoPilot system.
Flyte CoPilot eliminates the need for the SDK inside the container. Any inputs required by the user's container
are side-loaded into the `input_path`, and any outputs generated by the user's container within `output_path` are automatically uploaded.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### execute()
```python
def execute(
args,
kwargs,
) -> Any
```
Execute the function body in a Monty sandbox.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### forward()
```python
def forward(
args,
kwargs,
) -> Any
```
Bypass Monty and call the function directly (for local/debug execution).
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### override()
```python
def override(
short_name: Optional[str],
resources: Optional[Resources],
cache: Optional[CacheRequest],
retries: Union[int, RetryStrategy],
timeout: Optional[TimeoutType],
reusable: Union[ReusePolicy, Literal['off'], None],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
max_inline_io_bytes: int | None,
pod_template: Optional[Union[str, PodTemplate]],
queue: Optional[str],
interruptible: Optional[bool],
links: Tuple[Link, ...],
kwargs: **kwargs,
) -> TaskTemplate
```
Override various parameters of the task template. This allows for dynamic configuration of the task
when it is called, such as changing the image, resources, cache policy, etc.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional override for the short name of the task. |
| `resources` | `Optional[Resources]` | Optional override for the resources to use for the task. |
| `cache` | `Optional[CacheRequest]` | Optional override for the cache policy for the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional override for the number of retries for the task. |
| `timeout` | `Optional[TimeoutType]` | Optional override for the timeout for the task. |
| `reusable` | `Union[ReusePolicy, Literal['off'], None]` | Optional override for the reusability policy for the task. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional override for the environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional override for the secrets that will be injected into the task at runtime. |
| `max_inline_io_bytes` | `int \| None` | Optional override for the maximum allowed size (in bytes) for all inputs and outputs passed directly to the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional override for the pod template to use for the task. |
| `queue` | `Optional[str]` | Optional override for the queue to use for the task. |
| `interruptible` | `Optional[bool]` | Optional override for the interruptible policy for the task. |
| `links` | `Tuple[Link, ...]` | Optional override for the Links associated with the task. |
| `kwargs` | `**kwargs` | Additional keyword arguments for further overrides. Some fields like name, image, docs, and interface cannot be overridden. |
**Returns:** A new TaskTemplate instance with the overridden parameters.
### post()
```python
def post(
return_vals: Any,
) -> Any
```
This is the post-execute function that will be called after the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `return_vals` | `Any` | |
### pre()
```python
def pre(
args,
kwargs,
) -> Dict[str, Any]
```
This is the pre-execute function that will be called before the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### sql()
```python
def sql(
sctx: SerializationContext,
) -> Optional[str]
```
Returns the SQL statement for the task, if any. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.storage ===
# flyte.storage
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > ABFS** | Any Azure Blob Storage specific configuration. |
| **Flyte SDK > Packages > flyte.storage > GCS** | Any GCS specific configuration. |
| **Flyte SDK > Packages > flyte.storage > S3** | S3 specific configuration. |
| **Flyte SDK > Packages > flyte.storage > Storage** | Data storage configuration that applies across any provider. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > Methods > exists()** | Check if a path exists. |
| **Flyte SDK > Packages > flyte.storage > Methods > exists_sync()** | |
| **Flyte SDK > Packages > flyte.storage > Methods > get()** | |
| **Flyte SDK > Packages > flyte.storage > Methods > get_configured_fsspec_kwargs()** | |
| **Flyte SDK > Packages > flyte.storage > Methods > get_random_local_directory()** | |
| **Flyte SDK > Packages > flyte.storage > Methods > get_random_local_path()** | Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name. |
| **Flyte SDK > Packages > flyte.storage > Methods > get_stream()** | Get a stream of data from a remote location. |
| **Flyte SDK > Packages > flyte.storage > Methods > get_underlying_filesystem()** | |
| **Flyte SDK > Packages > flyte.storage > Methods > is_remote()** | Check whether a path points to a remote location. |
| **Flyte SDK > Packages > flyte.storage > Methods > join()** | Join multiple paths together. |
| **Flyte SDK > Packages > flyte.storage > open()** | Asynchronously open a file and return an async context manager. |
| **Flyte SDK > Packages > flyte.storage > put()** | |
| **Flyte SDK > Packages > flyte.storage > put_stream()** | Put a stream of data to a remote location. |
## Methods
#### exists()
```python
def exists(
path: str,
kwargs,
) -> bool
```
Check if a path exists.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | Path to be checked. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
**Returns:** True if the path exists, False otherwise.
#### exists_sync()
```python
def exists_sync(
path: str,
kwargs,
) -> bool
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `kwargs` | `**kwargs` | |
#### get()
```python
def get(
from_path: str,
to_path: Optional[str | pathlib.Path],
recursive: bool,
kwargs,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `from_path` | `str` | |
| `to_path` | `Optional[str \| pathlib.Path]` | |
| `recursive` | `bool` | |
| `kwargs` | `**kwargs` | |
#### get_configured_fsspec_kwargs()
```python
def get_configured_fsspec_kwargs(
protocol: typing.Optional[str],
anonymous: bool,
) -> typing.Dict[str, typing.Any]
```
| Parameter | Type | Description |
|-|-|-|
| `protocol` | `typing.Optional[str]` | |
| `anonymous` | `bool` | |
#### get_random_local_directory()
```python
def get_random_local_directory()
```
**Returns:** pathlib.Path
#### get_random_local_path()
```python
def get_random_local_path(
file_path_or_file_name: pathlib.Path | str | None,
) -> pathlib.Path
```
Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name
| Parameter | Type | Description |
|-|-|-|
| `file_path_or_file_name` | `pathlib.Path \| str \| None` | |
#### get_stream()
```python
def get_stream(
path: str,
chunk_size,
kwargs,
) -> AsyncGenerator[bytes, None]
```
Get a stream of data from a remote location.
This is useful for downloading streaming data from a remote location.
Example usage:
```python
import flyte.storage as storage
async for chunk in storage.get_stream(path="s3://my_bucket/my_file.txt"):
process(chunk)
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | Path to the remote location where the data will be downloaded. |
| `chunk_size` | | Size of each chunk to be read from the file. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
**Returns:** An async iterator that yields chunks of bytes.
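The `async for` consumption pattern shown above can be sketched with the standard library alone (no Flyte dependency; `read_in_chunks` is a stand-in for `storage.get_stream`):

```python
import asyncio
import io

async def read_in_chunks(buf: io.BytesIO, chunk_size: int = 4):
    # Stand-in for storage.get_stream: yield fixed-size byte chunks.
    while chunk := buf.read(chunk_size):
        yield chunk

async def main() -> list[bytes]:
    chunks = []
    async for chunk in read_in_chunks(io.BytesIO(b"hello world")):
        chunks.append(chunk)
    return chunks

chunks = asyncio.run(main())
print(chunks)  # [b'hell', b'o wo', b'rld']
```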
#### get_underlying_filesystem()
```python
def get_underlying_filesystem(
protocol: typing.Optional[str],
anonymous: bool,
path: typing.Optional[str],
kwargs,
) -> fsspec.AbstractFileSystem
```
| Parameter | Type | Description |
|-|-|-|
| `protocol` | `typing.Optional[str]` | |
| `anonymous` | `bool` | |
| `path` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
#### is_remote()
```python
def is_remote(
path: typing.Union[pathlib.Path | str],
) -> bool
```
Check whether the given path points to a remote location (such as an object-store URI) rather than the local filesystem.
| Parameter | Type | Description |
|-|-|-|
| `path` | `typing.Union[pathlib.Path \| str]` | |
#### join()
```python
def join(
paths: str,
) -> str
```
Join multiple paths together. This is a wrapper around os.path.join.
| Parameter | Type | Description |
|-|-|-|
| `paths` | `str` | Paths to be joined. |
#### open()
```python
def open(
path: str,
mode: str,
kwargs,
) -> AsyncReadableFile | AsyncWritableFile
```
Asynchronously open a file and return an async context manager.
This function checks if the underlying filesystem supports obstore bypass.
If it does, it uses obstore to open the file. Otherwise, it falls back to
the standard _open function which uses AsyncFileSystem.
It will raise NotImplementedError if neither obstore nor AsyncFileSystem is supported.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `mode` | `str` | |
| `kwargs` | `**kwargs` | |
#### put()
```python
def put(
from_path: str,
to_path: Optional[str],
recursive: bool,
batch_size: Optional[int],
kwargs,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `from_path` | `str` | |
| `to_path` | `Optional[str]` | |
| `recursive` | `bool` | |
| `batch_size` | `Optional[int]` | |
| `kwargs` | `**kwargs` | |
#### put_stream()
```python
def put_stream(
data_iterable: typing.AsyncIterable[bytes] | bytes,
name: str | None,
to_path: str | None,
kwargs,
) -> str
```
Put a stream of data to a remote location. This is useful for streaming data to a remote location.
Example usage:
```python
import flyte.storage as storage
storage.put_stream(iter([b'hello']), name="my_file.txt")
# or
storage.put_stream(iter([b'hello']), to_path="s3://my_bucket/my_file.txt")
```
| Parameter | Type | Description |
|-|-|-|
| `data_iterable` | `typing.AsyncIterable[bytes] \| bytes` | Iterable of bytes to be streamed. |
| `name` | `str \| None` | Name of the file to be created. If not provided, a random name will be generated. |
| `to_path` | `str \| None` | Path to the remote location where the data will be stored. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
**Returns:** The path to the remote location where the data was stored.
## Subpages
- **Flyte SDK > Packages > flyte.storage > ABFS**
- **Flyte SDK > Packages > flyte.storage > GCS**
- **Flyte SDK > Packages > flyte.storage > S3**
- **Flyte SDK > Packages > flyte.storage > Storage**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.storage/abfs ===
# ABFS
**Package:** `flyte.storage`
Any Azure Blob Storage specific configuration.
## Parameters
```python
class ABFS(
retries: int,
backoff: datetime.timedelta,
enable_debug: bool,
attach_execution_metadata: bool,
account_name: typing.Optional[str],
account_key: typing.Optional[str],
tenant_id: typing.Optional[str],
client_id: typing.Optional[str],
client_secret: typing.Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `retries` | `int` | |
| `backoff` | `datetime.timedelta` | |
| `enable_debug` | `bool` | |
| `attach_execution_metadata` | `bool` | |
| `account_name` | `typing.Optional[str]` | |
| `account_key` | `typing.Optional[str]` | |
| `tenant_id` | `typing.Optional[str]` | |
| `client_id` | `typing.Optional[str]` | |
| `client_secret` | `typing.Optional[str]` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > ABFS > Methods > auto()** | Construct the config object automatically from environment variables. |
| **Flyte SDK > Packages > flyte.storage > ABFS > Methods > get_fsspec_kwargs()** | Returns the configuration as kwargs for constructing an fsspec filesystem. |
### auto()
```python
def auto()
```
Construct the config object automatically from environment variables.
### get_fsspec_kwargs()
```python
def get_fsspec_kwargs(
anonymous: bool,
kwargs,
) -> typing.Dict[str, typing.Any]
```
Returns the configuration as kwargs for constructing an fsspec filesystem.
| Parameter | Type | Description |
|-|-|-|
| `anonymous` | `bool` | |
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.storage/gcs ===
# GCS
**Package:** `flyte.storage`
Any GCS specific configuration.
## Parameters
```python
class GCS(
retries: int,
backoff: datetime.timedelta,
enable_debug: bool,
attach_execution_metadata: bool,
gsutil_parallelism: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `retries` | `int` | |
| `backoff` | `datetime.timedelta` | |
| `enable_debug` | `bool` | |
| `attach_execution_metadata` | `bool` | |
| `gsutil_parallelism` | `bool` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > GCS > Methods > auto()** | Construct the config object automatically from environment variables. |
| **Flyte SDK > Packages > flyte.storage > GCS > Methods > get_fsspec_kwargs()** | Returns the configuration as kwargs for constructing an fsspec filesystem. |
### auto()
```python
def auto()
```
Construct the config object automatically from environment variables.
### get_fsspec_kwargs()
```python
def get_fsspec_kwargs(
anonymous: bool,
kwargs,
) -> typing.Dict[str, typing.Any]
```
Returns the configuration as kwargs for constructing an fsspec filesystem.
| Parameter | Type | Description |
|-|-|-|
| `anonymous` | `bool` | |
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.storage/s3 ===
# S3
**Package:** `flyte.storage`
S3 specific configuration.
Authentication resolution used by Flyte + obstore:
1. If explicit static credentials are provided via Flyte S3 inputs/environment
(`access_key_id`/`secret_access_key`), those are used.
2. If static credentials are not provided, and both `AWS_PROFILE` and
`AWS_CONFIG_FILE` are available, Flyte configures a boto3-backed obstore
credential provider so profile-based auth can be used. This requires that the `boto3` library
is installed.
3. If neither of the above applies, obstore uses the default AWS credential chain
(for remote runs this commonly resolves via workload identity / IAM attached to
the service account and then IMDS fallbacks where applicable).
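The three-step resolution order above can be sketched as a plain-Python chain (illustrative only; the real resolution happens inside obstore and boto3):

```python
import os

def resolve_s3_auth(access_key_id=None, secret_access_key=None, env=None):
    """Hypothetical sketch of the credential resolution order."""
    env = os.environ if env is None else env
    # Step 1: explicit static credentials win.
    if access_key_id and secret_access_key:
        return "static"
    # Step 2: profile-based auth when both AWS_PROFILE and AWS_CONFIG_FILE are set.
    if env.get("AWS_PROFILE") and env.get("AWS_CONFIG_FILE"):
        return "profile"
    # Step 3: fall back to the default AWS credential chain (IAM / IMDS).
    return "default-chain"

print(resolve_s3_auth("AKIA...", "secret"))  # static
print(resolve_s3_auth(env={"AWS_PROFILE": "dev",
                           "AWS_CONFIG_FILE": "~/.aws/config"}))  # profile
print(resolve_s3_auth(env={}))  # default-chain
```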
## Parameters
```python
class S3(
retries: int,
backoff: datetime.timedelta,
enable_debug: bool,
attach_execution_metadata: bool,
endpoint: typing.Optional[str],
access_key_id: typing.Optional[str],
secret_access_key: typing.Optional[str],
region: typing.Optional[str],
addressing_style: typing.Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `retries` | `int` | |
| `backoff` | `datetime.timedelta` | |
| `enable_debug` | `bool` | |
| `attach_execution_metadata` | `bool` | |
| `endpoint` | `typing.Optional[str]` | |
| `access_key_id` | `typing.Optional[str]` | |
| `secret_access_key` | `typing.Optional[str]` | |
| `region` | `typing.Optional[str]` | |
| `addressing_style` | `typing.Optional[str]` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > S3 > Methods > auto()** | |
| **Flyte SDK > Packages > flyte.storage > S3 > Methods > for_sandbox()** | |
| **Flyte SDK > Packages > flyte.storage > S3 > Methods > get_fsspec_kwargs()** | Returns the configuration as kwargs for constructing an fsspec filesystem. |
### auto()
```python
def auto(
region: str | None,
) -> S3
```
| Parameter | Type | Description |
|-|-|-|
| `region` | `str \| None` | |
**Returns:** Config
### for_sandbox()
```python
def for_sandbox()
```
### get_fsspec_kwargs()
```python
def get_fsspec_kwargs(
anonymous: bool,
kwargs,
) -> typing.Dict[str, typing.Any]
```
Returns the configuration as kwargs for constructing an fsspec filesystem.
| Parameter | Type | Description |
|-|-|-|
| `anonymous` | `bool` | |
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.storage/storage ===
# Storage
**Package:** `flyte.storage`
Data storage configuration that applies across any provider.
## Parameters
```python
class Storage(
retries: int,
backoff: datetime.timedelta,
enable_debug: bool,
attach_execution_metadata: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `retries` | `int` | |
| `backoff` | `datetime.timedelta` | |
| `enable_debug` | `bool` | |
| `attach_execution_metadata` | `bool` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > Storage > Methods > auto()** | Construct the config object automatically from environment variables. |
| **Flyte SDK > Packages > flyte.storage > Storage > Methods > get_fsspec_kwargs()** | Returns the configuration as kwargs for constructing an fsspec filesystem. |
### auto()
```python
def auto()
```
Construct the config object automatically from environment variables.
### get_fsspec_kwargs()
```python
def get_fsspec_kwargs(
anonymous: bool,
kwargs,
) -> typing.Dict[str, typing.Any]
```
Returns the configuration as kwargs for constructing an fsspec filesystem.
| Parameter | Type | Description |
|-|-|-|
| `anonymous` | `bool` | |
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.syncify ===
# flyte.syncify
# Syncify Module
This module provides the `syncify` decorator and the `Syncify` class.
The decorator can be used to convert asynchronous functions or methods into synchronous ones.
This is useful for integrating async code into synchronous contexts.
Every asynchronous function or method wrapped with `syncify` can be called synchronously using the
parenthesis `()` operator, or asynchronously using the `.aio()` method.
Example:
```python
from flyte.syncify import syncify
@syncify
async def async_function(x: str) -> str:
return f"Hello, Async World {x}!"
# now you can call it synchronously
result = async_function("Async World") # Note: no .aio() needed for sync calls
print(result)
# Output: Hello, Async World Async World!
# or call it asynchronously
async def main():
result = await async_function.aio("World") # Note the use of .aio() for async calls
print(result)
```
## Creating a Syncify Instance
```python
from flyte.syncify import Syncify
syncer = Syncify("my_syncer")
# Now you can use `syncer` to decorate your async functions or methods
```
## How does it work?
The Syncify class wraps asynchronous functions, classmethods, instance methods, and static methods to
provide a synchronous interface. The wrapped methods are always executed in the context of a background event loop,
whether they are called synchronously or asynchronously. This allows seamless integration of async code, since some
async libraries (for example, `grpc.aio`) capture the event loop they were created on. In such cases, Syncify ensures
that the async function executes in the context of the background loop.
To use it correctly with `grpc.aio`, wrap every `grpc.aio` channel creation and client invocation
with the same `Syncify` instance, so that the async code always runs in the same event-loop context.
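The background-loop mechanism described above can be sketched with the standard library (a minimal stand-in, not the SDK's implementation):

```python
import asyncio
import threading

class MiniSyncify:
    """Run coroutines on one dedicated background event loop."""

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        # The loop lives in a daemon thread for the life of the wrapper.
        threading.Thread(target=self._loop.run_forever, daemon=True).start()

    def __call__(self, fn):
        def sync_wrapper(*args, **kwargs):
            # Submit the coroutine to the background loop and block for the
            # result, so libraries that capture "their" loop always see the same one.
            fut = asyncio.run_coroutine_threadsafe(fn(*args, **kwargs), self._loop)
            return fut.result()
        return sync_wrapper

syncer = MiniSyncify()

@syncer
async def greet(x: str) -> str:
    return f"Hello, {x}!"

print(greet("World"))  # Hello, World!
```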
## Directory
### Classes
| Class | Description |
|-|-|
| [`Syncify`](syncify/page.md) | A decorator to convert asynchronous functions or methods into synchronous ones. |
## Subpages
- **Flyte SDK > Packages > flyte.syncify > Syncify**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.syncify/syncify ===
# Syncify
**Package:** `flyte.syncify`
A decorator to convert asynchronous functions or methods into synchronous ones.
This is useful for integrating async code into synchronous contexts.
Example:
```python
syncer = Syncify()
@syncer
async def async_function(x: str) -> str:
return f"Hello, Async World {x}!"
# now you can call it synchronously
result = async_function("Async World")
print(result)
# Output: Hello, Async World Async World!
# or call it asynchronously
async def main():
result = await async_function.aio("World")
print(result)
```
## Parameters
```python
class Syncify(
name: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.types ===
# flyte.types
# Flyte Type System
The Flyte type system provides a way to define, transform, and manipulate types in Flyte workflows.
Since the data flowing through Flyte often has to cross process, container, and language boundaries, the type system
is designed to be serializable to a universal format that can be understood across different environments. This
universal format is based on Protocol Buffers. The types are called LiteralTypes, and the runtime
representation of data is called Literals.
The type system includes:
- **TypeEngine**: The core engine that manages type transformations and serialization. This is the main entry point
for all internal type transformation and serialization logic.
- **TypeTransformer**: A class that defines how to transform one type to another. This is extensible,
allowing users to define custom types and transformations.
- **Renderable**: An interface for types that can be rendered as HTML for output to a flyte.report.
It is always possible to bypass the type system and use the `FlytePickle` type to serialize any Python object
into pickle format. The pickle format is not human-readable, but it can be passed between Flyte tasks that are
written in Python. Pickled objects cannot be represented in the UI and may be inefficient for large datasets.
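To make the TypeEngine/TypeTransformer relationship concrete, here is a miniature registry in plain Python (hypothetical names; the real engine converts to Protocol Buffer literals, not dicts):

```python
from typing import Any, Callable, Dict, Tuple, Type

# Hypothetical miniature of the TypeEngine registry pattern:
# a transformer maps a Python value to a serializable "literal" and back.
_registry: Dict[Type, Tuple[Callable[[Any], Any], Callable[[Any], Any]]] = {}

def register(py_type: Type, to_lit: Callable, to_py: Callable) -> None:
    # Analogous to TypeEngine.register(): attach a transformer to a Python type.
    _registry[py_type] = (to_lit, to_py)

def to_literal(value: Any) -> Any:
    to_lit, _ = _registry[type(value)]
    return to_lit(value)

def to_python(py_type: Type, literal: Any) -> Any:
    _, to_py = _registry[py_type]
    return to_py(literal)

register(int, lambda v: {"scalar": str(v)}, lambda lit: int(lit["scalar"]))

lit = to_literal(42)
print(lit)                  # {'scalar': '42'}
print(to_python(int, lit))  # 42
```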
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > FlytePickle** | This type is only used by flytekit internally. |
| **Flyte SDK > Packages > flyte.types > TypeEngine** | Core Extensible TypeEngine of Flytekit. |
| **Flyte SDK > Packages > flyte.types > TypeTransformer** | Base transformer type that should be implemented for every python native type that can be handled by flytekit. |
### Protocols
| Protocol | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > Renderable** | |
### Errors
| Exception | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > TypeTransformerFailedError** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > Methods > guess_interface()** | Returns the interface of the task with guessed types, as types may not be present in current env. |
| **Flyte SDK > Packages > flyte.types > Methods > literal_string_repr()** | This method is used to convert a literal map to a string representation. |
## Methods
#### guess_interface()
```python
def guess_interface(
interface: flyteidl2.core.interface_pb2.TypedInterface,
default_inputs: typing.Optional[typing.Iterable[flyteidl2.task.common_pb2.NamedParameter]],
) -> flyte.models.NativeInterface
```
Returns the interface of the task with guessed types, as types may not be present in current env.
| Parameter | Type | Description |
|-|-|-|
| `interface` | `flyteidl2.core.interface_pb2.TypedInterface` | |
| `default_inputs` | `typing.Optional[typing.Iterable[flyteidl2.task.common_pb2.NamedParameter]]` | |
#### literal_string_repr()
```python
def literal_string_repr(
lm: typing.Union[flyteidl2.core.literals_pb2.Literal, flyteidl2.task.common_pb2.NamedLiteral, flyteidl2.task.common_pb2.Inputs, flyteidl2.task.common_pb2.Outputs, flyteidl2.core.literals_pb2.LiteralMap, typing.Dict[str, flyteidl2.core.literals_pb2.Literal]],
) -> typing.Dict[str, typing.Any]
```
This method is used to convert a literal map to a string representation.
| Parameter | Type | Description |
|-|-|-|
| `lm` | `typing.Union[flyteidl2.core.literals_pb2.Literal, flyteidl2.task.common_pb2.NamedLiteral, flyteidl2.task.common_pb2.Inputs, flyteidl2.task.common_pb2.Outputs, flyteidl2.core.literals_pb2.LiteralMap, typing.Dict[str, flyteidl2.core.literals_pb2.Literal]]` | |
## Subpages
- **Flyte SDK > Packages > flyte.types > FlytePickle**
- **Flyte SDK > Packages > flyte.types > Renderable**
- **Flyte SDK > Packages > flyte.types > TypeEngine**
- **Flyte SDK > Packages > flyte.types > TypeTransformer**
- **Flyte SDK > Packages > flyte.types > TypeTransformerFailedError**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.types/flytepickle ===
# FlytePickle
**Package:** `flyte.types`
This type is only used by flytekit internally. User should not use this type.
Any type that flyte can't recognize will become FlytePickle
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > FlytePickle > Methods > from_pickle()** | |
| **Flyte SDK > Packages > flyte.types > FlytePickle > Methods > python_type()** | |
| **Flyte SDK > Packages > flyte.types > FlytePickle > Methods > to_pickle()** | |
### from_pickle()
```python
def from_pickle(
uri: str,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `uri` | `str` | |
### python_type()
```python
def python_type()
```
### to_pickle()
```python
def to_pickle(
python_val: typing.Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `typing.Any` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.types/renderable ===
# Renderable
**Package:** `flyte.types`
```python
protocol Renderable()
```
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > Renderable > Methods > to_html()** | Convert an object (e.g., markdown, pandas.DataFrame) to HTML. |
### to_html()
```python
def to_html(
python_value: typing.Any,
) -> str
```
Convert an object (e.g., markdown, pandas.DataFrame) to HTML and return the HTML as a unicode string.
**Returns:** An HTML document as a string.
| Parameter | Type | Description |
|-|-|-|
| `python_value` | `typing.Any` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.types/typeengine ===
# TypeEngine
**Package:** `flyte.types`
Core extensible TypeEngine of flytekit. This should be used to extend the capabilities of flytekit's type system.
Users can implement their own TypeTransformers and register them with the TypeEngine. This allows special
handling of user objects.
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > dict_to_literal_map()** | Given a dictionary mapping string keys to python values and a dictionary containing guessed types for such. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > get_available_transformers()** | Returns all python types for which transformers are available. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > get_transformer()** | Implements a recursive search for the transformer. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > guess_python_type()** | Transforms a flyte-specific `LiteralType` to a regular python type. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > guess_python_types()** | Transforms a list of flyte-specific `VariableEntry` objects to a dictionary of regular python types. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > lazy_import_transformers()** | Only load the transformers if needed. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > literal_map_to_kwargs()** | Given a `LiteralMap` (usually an input into a task - intermediate), convert to kwargs for the task. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > named_tuple_to_variable_map()** | Converts a python-native `NamedTuple` to a flyte-specific VariableMap of named literals. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > register()** | Registers a transformer for all types that respond with the right type annotation when you use the `type(...)` function. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > register_additional_type()** | |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > register_restricted_type()** | |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > to_html()** | |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > to_literal()** | |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > to_literal_checks()** | |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > to_literal_type()** | Converts a python type into a flyte specific `LiteralType`. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > to_python_value()** | Converts a Literal value with an expected python type into a python value. |
| **Flyte SDK > Packages > flyte.types > TypeEngine > Methods > unwrap_offloaded_literal()** | |
### dict_to_literal_map()
```python
def dict_to_literal_map(
d: typing.Dict[str, typing.Any],
type_hints: Optional[typing.Dict[str, type]],
) -> LiteralMap
```
Given a dictionary mapping string keys to python values and a dictionary of guessed types for those keys, convert to a `LiteralMap`.
| Parameter | Type | Description |
|-|-|-|
| `d` | `typing.Dict[str, typing.Any]` | |
| `type_hints` | `Optional[typing.Dict[str, type]]` | |
### get_available_transformers()
```python
def get_available_transformers()
```
Returns all python types for which transformers are available.
### get_transformer()
```python
def get_transformer(
python_type: Type,
) -> TypeTransformer
```
Implements a recursive search for the transformer.
| Parameter | Type | Description |
|-|-|-|
| `python_type` | `Type` | |
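The recursive search can be pictured as an MRO walk over a registry of transformers. A plain-Python sketch (this toy `Registry` is illustrative only; the real `TypeEngine` also handles generics, protobuf literals, and lazy imports):

```python
from typing import Any, Dict


class Registry:
    """Toy stand-in for the TypeEngine's transformer registry."""

    def __init__(self) -> None:
        self._transformers: Dict[type, Any] = {}

    def register(self, python_type: type, transformer: Any) -> None:
        self._transformers[python_type] = transformer

    def get_transformer(self, python_type: type) -> Any:
        # Recursive search: try the exact type first, then walk up the MRO so
        # a transformer registered for a base class also serves subclasses.
        for candidate in python_type.__mro__:
            if candidate in self._transformers:
                return self._transformers[candidate]
        raise ValueError(f"no transformer registered for {python_type}")


registry = Registry()
registry.register(int, "int-transformer")


class Flag(int):
    """Subclass with no transformer of its own; resolves via the MRO walk."""
```

With this sketch, `registry.get_transformer(Flag)` resolves to the `int` transformer, while an unregistered type raises `ValueError`.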
### guess_python_type()
```python
def guess_python_type(
flyte_type: LiteralType,
) -> Type[T]
```
Transforms a flyte-specific `LiteralType` to a regular python type.
| Parameter | Type | Description |
|-|-|-|
| `flyte_type` | `LiteralType` | |
### guess_python_types()
```python
def guess_python_types(
flyte_variable_list: typing.List[interface_pb2.VariableEntry],
) -> typing.Dict[str, Type[Any]]
```
Transforms a list of flyte-specific `VariableEntry` objects to a dictionary mapping variable names to regular python types.
| Parameter | Type | Description |
|-|-|-|
| `flyte_variable_list` | `typing.List[interface_pb2.VariableEntry]` | |
### lazy_import_transformers()
```python
def lazy_import_transformers()
```
Only load the transformers if needed.
### literal_map_to_kwargs()
```python
def literal_map_to_kwargs(
lm: LiteralMap,
python_types: typing.Optional[typing.Dict[str, type]],
literal_types: typing.Optional[typing.Dict[str, interface_pb2.Variable]],
) -> typing.Dict[str, typing.Any]
```
Given a `LiteralMap` (usually an intermediate input into a task), convert it to kwargs for the task.
| Parameter | Type | Description |
|-|-|-|
| `lm` | `LiteralMap` | |
| `python_types` | `typing.Optional[typing.Dict[str, type]]` | |
| `literal_types` | `typing.Optional[typing.Dict[str, interface_pb2.Variable]]` | |
### named_tuple_to_variable_map()
```python
def named_tuple_to_variable_map(
t: typing.NamedTuple,
) -> interface_pb2.VariableMap
```
Converts a python-native `NamedTuple` to a flyte-specific VariableMap of named literals.
| Parameter | Type | Description |
|-|-|-|
| `t` | `typing.NamedTuple` | |
### register()
```python
def register(
transformer: TypeTransformer,
additional_types: Optional[typing.List[Type]],
)
```
This should be used for all types that respond with the right type annotation when you use the `type(...)` function.
| Parameter | Type | Description |
|-|-|-|
| `transformer` | `TypeTransformer` | |
| `additional_types` | `Optional[typing.List[Type]]` | |
### register_additional_type()
```python
def register_additional_type(
transformer: TypeTransformer[T],
additional_type: Type[T],
override,
)
```
| Parameter | Type | Description |
|-|-|-|
| `transformer` | `TypeTransformer[T]` | |
| `additional_type` | `Type[T]` | |
| `override` | | |
### register_restricted_type()
```python
def register_restricted_type(
name: str,
type: Type[T],
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `type` | `Type[T]` | |
### to_html()
```python
def to_html(
python_val: typing.Any,
expected_python_type: Type[typing.Any],
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `typing.Any` | |
| `expected_python_type` | `Type[typing.Any]` | |
### to_literal()
```python
def to_literal(
python_val: typing.Any,
python_type: Type[T],
expected: types_pb2.LiteralType,
) -> literals_pb2.Literal
```
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `typing.Any` | |
| `python_type` | `Type[T]` | |
| `expected` | `types_pb2.LiteralType` | |
### to_literal_checks()
```python
def to_literal_checks(
python_val: typing.Any,
python_type: Type[T],
expected: LiteralType,
)
```
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `typing.Any` | |
| `python_type` | `Type[T]` | |
| `expected` | `LiteralType` | |
### to_literal_type()
```python
def to_literal_type(
python_type: Type[T],
) -> LiteralType
```
Converts a python type into a flyte-specific `LiteralType`.
| Parameter | Type | Description |
|-|-|-|
| `python_type` | `Type[T]` | |
### to_python_value()
```python
def to_python_value(
lv: Literal,
expected_python_type: Type,
) -> typing.Any
```
Converts a Literal value with an expected python type into a python value.
| Parameter | Type | Description |
|-|-|-|
| `lv` | `Literal` | |
| `expected_python_type` | `Type` | |
### unwrap_offloaded_literal()
```python
def unwrap_offloaded_literal(
lv: literals_pb2.Literal,
) -> literals_pb2.Literal
```
| Parameter | Type | Description |
|-|-|-|
| `lv` | `literals_pb2.Literal` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.types/typetransformer ===
# TypeTransformer
**Package:** `flyte.types`
Base transformer type that should be implemented for every python-native type that can be handled by flytekit.
## Parameters
```python
class TypeTransformer(
name: str,
t: Type[T],
enable_type_assertions: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `t` | `Type[T]` | |
| `enable_type_assertions` | `bool` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `name` | `None` | |
| `python_type` | `None` | The python type handled by this transformer. |
| `type_assertions_enabled` | `None` | Indicates whether the transformer wants type assertions to be enabled at the core type engine layer. |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > assert_type()** | |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > from_binary_idl()** | Handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and attribute access. |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > get_literal_type()** | Converts the python type to a Flyte LiteralType. |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > guess_python_type()** | Converts the Flyte LiteralType to a python object type. |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > isinstance_generic()** | |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > schema_match()** | Check if a JSON schema fragment matches this transformer's python_type. |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > to_html()** | Converts a python value (dataframe, int, float) to an HTML string wrapped in an HTML div. |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > to_literal()** | Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type. |
| **Flyte SDK > Packages > flyte.types > TypeTransformer > Methods > to_python_value()** | Converts the given Literal to a Python Type. |
### assert_type()
```python
def assert_type(
t: Type[T],
v: T,
)
```
| Parameter | Type | Description |
|-|-|-|
| `t` | `Type[T]` | |
| `v` | `T` | |
### from_binary_idl()
```python
def from_binary_idl(
binary_idl_object: Binary,
expected_python_type: Type[T],
) -> Optional[T]
```
This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and attribute access.
For untyped dicts, dataclasses, and Pydantic BaseModels, the life cycle (untyped dict as example) is:
python val -> (to_literal) -> msgpack bytes -> binary literal scalar -> (from_binary_idl) -> msgpack bytes -> python val
For attribute access, the life cycle is:
python val -> (to_literal) -> msgpack bytes -> binary literal scalar -> (propeller attribute access) -> resolved golang value -> binary literal scalar -> (from_binary_idl) -> msgpack bytes -> python val
| Parameter | Type | Description |
|-|-|-|
| `binary_idl_object` | `Binary` | |
| `expected_python_type` | `Type[T]` | |
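The untyped-dict life cycle boils down to a serialize/wrap/unwrap round trip. A minimal sketch, using the stdlib `json` module as a stand-in for msgpack and a hypothetical `BinaryScalar` in place of the real binary literal scalar:

```python
import json
from dataclasses import dataclass


@dataclass
class BinaryScalar:
    """Stand-in for a binary literal scalar: raw bytes plus a codec tag.
    The real SDK uses msgpack; stdlib json stands in for it here."""
    value: bytes
    tag: str


def to_literal(python_val: dict) -> BinaryScalar:
    # Life cycle, first half: python val -> serialized bytes -> binary scalar.
    return BinaryScalar(value=json.dumps(python_val).encode("utf-8"), tag="json")


def from_binary_idl(scalar: BinaryScalar) -> dict:
    # Life cycle, second half: binary scalar -> serialized bytes -> python val.
    if scalar.tag != "json":
        raise ValueError(f"unsupported codec: {scalar.tag}")
    return json.loads(scalar.value.decode("utf-8"))


roundtripped = from_binary_idl(to_literal({"city": "SF", "temp": 22}))
```

The codec tag check mirrors why the real implementation must know which serialization produced the bytes before it can decode them.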
### get_literal_type()
```python
def get_literal_type(
t: Type[T],
) -> LiteralType
```
Converts the python type to a Flyte `LiteralType`.
| Parameter | Type | Description |
|-|-|-|
| `t` | `Type[T]` | |
### guess_python_type()
```python
def guess_python_type(
literal_type: LiteralType,
) -> Type[T]
```
Converts the Flyte LiteralType to a python object type.
| Parameter | Type | Description |
|-|-|-|
| `literal_type` | `LiteralType` | |
### isinstance_generic()
```python
def isinstance_generic(
obj,
generic_alias,
)
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | | |
| `generic_alias` | | |
### schema_match()
```python
def schema_match(
schema: dict,
) -> bool
```
Check if a JSON schema fragment matches this transformer's python_type.
For BaseModel subclasses, automatically compares the schema's title, type, and
required fields against the type's own JSON schema. For other types, returns
False by default; override if needed.
| Parameter | Type | Description |
|-|-|-|
| `schema` | `dict` | |
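The BaseModel comparison can be sketched as a field-by-field check. Note the two-argument signature and the `Point` schema below are illustrative; the real method takes only the candidate `schema` and reads the type's own JSON schema internally:

```python
from typing import Any, Dict


def schema_match(schema: Dict[str, Any], own_schema: Dict[str, Any]) -> bool:
    """Illustrative version of the documented check: compare the fragment's
    title, type, and required fields against the type's own JSON schema."""
    return (
        schema.get("title") == own_schema.get("title")
        and schema.get("type") == own_schema.get("type")
        and sorted(schema.get("required", [])) == sorted(own_schema.get("required", []))
    )


# JSON schema a hypothetical BaseModel subclass `Point(x: int, y: int)`
# might emit (for illustration only).
own = {"title": "Point", "type": "object", "required": ["x", "y"]}
```

Sorting `required` before comparing makes the check order-insensitive, since JSON schema imposes no ordering on that list.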
### to_html()
```python
def to_html(
python_val: T,
expected_python_type: Type[T],
) -> str
```
Converts a python value (dataframe, int, float) to an HTML string; the result will be wrapped in an HTML div.
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `T` | |
| `expected_python_type` | `Type[T]` | |
### to_literal()
```python
def to_literal(
python_val: T,
python_type: Type[T],
expected: LiteralType,
) -> Literal
```
Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type.
Implementers should refrain from using `type(python_val)` and should instead rely on the passed-in `python_type`. If these do not match (or are not allowed), the transformer implementer should raise an `AssertionError` that clearly states the mismatch.
| Parameter | Type | Description |
|-|-|-|
| `python_val` | `T` | The actual value to be transformed |
| `python_type` | `Type[T]` | The assumed type of the value (this matches the declared type on the function) |
| `expected` | `LiteralType` | Expected Literal Type |
### to_python_value()
```python
def to_python_value(
lv: Literal,
expected_python_type: Type[T],
) -> Optional[T]
```
Converts the given Literal to a python value of the expected type. If the conversion cannot be done, an `AssertionError` should be raised.
| Parameter | Type | Description |
|-|-|-|
| `lv` | `Literal` | The received literal Value |
| `expected_python_type` | `Type[T]` | Expected native python type that should be returned |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.types/typetransformerfailederror ===
# TypeTransformerFailedError
**Package:** `flyte.types`
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations ===
# Integrations
API reference for Flyte integration plugins.
## Subpages
- **Integrations > Anthropic**
- **Integrations > BigQuery**
- **Integrations > Code generation**
- **Integrations > Dask**
- **Integrations > Databricks**
- **Integrations > Gemini**
- **Integrations > Human-in-the-Loop**
- **Integrations > JSONL**
- **Integrations > MLflow**
- **Integrations > OpenAI**
- **Integrations > Polars**
- **Integrations > PyTorch**
- **Integrations > Ray**
- **Integrations > SGLang**
- **Integrations > Snowflake**
- **Integrations > Spark**
- **Integrations > Union**
- **Integrations > vLLM**
- **Integrations > Weights & Biases**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/anthropic ===
# Anthropic
## Subpages
- **Integrations > Anthropic > Classes**
- **Integrations > Anthropic > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/anthropic/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Anthropic > Packages > flyteplugins.anthropic > Agent** | A Claude agent configuration. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/anthropic/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Anthropic > Packages > flyteplugins.anthropic** | Anthropic Claude plugin for Flyte. |
## Subpages
- **Integrations > Anthropic > Packages > flyteplugins.anthropic**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/anthropic/packages/flyteplugins.anthropic ===
# flyteplugins.anthropic
Anthropic Claude plugin for Flyte.
This plugin provides integration between Flyte tasks and Anthropic's Claude API,
enabling you to use Flyte tasks as tools for Claude agents. Tool calls run with
full Flyte observability, retries, and caching.
Key features:
- Use any Flyte task as a Claude tool via `function_tool`
- Full agent loop with automatic tool dispatch via `run_agent`
- Configurable agent via `Agent` (model, system prompt, tools, iteration limits)
Basic usage example:
```python
import flyte
from flyteplugins.anthropic import Agent, function_tool, run_agent

env = flyte.TaskEnvironment(
    name="agent_env",
    image=flyte.Image.from_debian_base(name="agent").with_pip_packages(
        "flyteplugins-anthropic"
    ),
)

@env.task
async def get_weather(city: str) -> str:
    '''Get the current weather for a city.'''
    return f"Weather in {city}: sunny, 22°C"

weather_tool = function_tool(get_weather)

@env.task
async def run_weather_agent(question: str) -> str:
    return await run_agent(
        prompt=question,
        tools=[weather_tool],
        model="claude-sonnet-4-20250514",
    )
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Anthropic > Packages > flyteplugins.anthropic > Agent** | A Claude agent configuration. |
### Methods
| Method | Description |
|-|-|
| **Integrations > Anthropic > Packages > flyteplugins.anthropic > Methods > function_tool()** | Convert a function or Flyte task to an Anthropic-compatible tool. |
| **Integrations > Anthropic > Packages > flyteplugins.anthropic > Methods > run_agent()** | Run a Claude agent with the given tools and prompt. |
## Methods
#### function_tool()
```python
def function_tool(
func: typing.Union[flyte._task.AsyncFunctionTaskTemplate, typing.Callable, NoneType],
name: str | None,
description: str | None,
) -> FunctionTool | partial[FunctionTool]
```
Convert a function or Flyte task to an Anthropic-compatible tool.
This function converts a Python function, @flyte.trace decorated function,
or Flyte task into a FunctionTool that can be used with Claude's tool use API.
The input_schema is derived via the Flyte type engine, producing a JSON schema. This ensures that Literal types, dataclasses, FlyteFile, and other Flyte-native
types are represented correctly.
For @flyte.trace decorated functions, the tracing context is preserved
automatically since functools.wraps maintains the original function's metadata.
Example:
```python
@env.task
async def get_weather(city: str) -> str:
    '''Get the current weather for a city.'''
    return f"Weather in {city}: sunny"

tool = function_tool(get_weather)
```
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Union[flyte._task.AsyncFunctionTaskTemplate, typing.Callable, NoneType]` | The function or Flyte task to convert. |
| `name` | `str \| None` | Optional custom name for the tool. Defaults to the function name. |
| `description` | `str \| None` | Optional custom description. Defaults to the function's docstring. |
**Returns**
A FunctionTool instance that can be used with run_agent().
#### run_agent()
```python
def run_agent(
prompt: str,
tools: list[flyteplugins.anthropic.agents._function_tools.FunctionTool] | None,
agent: flyteplugins.anthropic.agents._function_tools.Agent | None,
model: str,
system: str | None,
max_tokens: int,
max_iterations: int,
api_key: str | None,
) -> str
```
Run a Claude agent with the given tools and prompt.
This function creates a Claude conversation loop that can use tools
to accomplish tasks. It handles the back-and-forth of tool calls
and responses until the agent produces a final text response.
Example:
```python
result = await run_agent(
    prompt="What's the weather in SF?",
    tools=[function_tool(get_weather)],
)
```
| Parameter | Type | Description |
|-|-|-|
| `prompt` | `str` | The user prompt to send to the agent. |
| `tools` | `list[flyteplugins.anthropic.agents._function_tools.FunctionTool] \| None` | List of FunctionTool instances to make available to the agent. |
| `agent` | `flyteplugins.anthropic.agents._function_tools.Agent \| None` | Optional Agent configuration. If provided, overrides other params. |
| `model` | `str` | The Claude model to use. |
| `system` | `str \| None` | Optional system prompt. |
| `max_tokens` | `int` | Maximum tokens in the response. |
| `max_iterations` | `int` | Maximum number of tool call iterations. |
| `api_key` | `str \| None` | Anthropic API key. Defaults to ANTHROPIC_API_KEY env var. |
**Returns**
The final text response from the agent.
## Subpages
- **Integrations > Anthropic > Packages > flyteplugins.anthropic > Agent**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/anthropic/packages/flyteplugins.anthropic/agent ===
# Agent
**Package:** `flyteplugins.anthropic`
A Claude agent configuration.
This class represents the configuration for a Claude agent, including
the model to use, system instructions, and available tools.
Attributes:
name: A human-readable name for this agent. Used for logging and
identification only; not sent to the API.
instructions: The system prompt passed to Claude on every turn.
Describes the agent's role, tone, and constraints.
model: The Claude model ID to use, e.g. `"claude-sonnet-4-20250514"`.
tools: List of `FunctionTool` instances the agent can invoke.
Create tools with `function_tool()`.
max_tokens: Maximum number of tokens in each Claude response.
max_iterations: Maximum number of tool-call / response cycles before
`run_agent` returns with a timeout message.
## Parameters
```python
class Agent(
name: str,
instructions: str,
model: str,
tools: list[flyteplugins.anthropic.agents._function_tools.FunctionTool],
max_tokens: int,
max_iterations: int,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `instructions` | `str` | |
| `model` | `str` | |
| `tools` | `list[flyteplugins.anthropic.agents._function_tools.FunctionTool]` | |
| `max_tokens` | `int` | |
| `max_iterations` | `int` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Anthropic > Packages > flyteplugins.anthropic > Agent > Methods > get_anthropic_tools()** | Get tool definitions in Anthropic format. |
### get_anthropic_tools()
```python
def get_anthropic_tools()
```
Get tool definitions in Anthropic format.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/bigquery ===
# BigQuery
## Subpages
- **Integrations > BigQuery > Classes**
- **Integrations > BigQuery > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/bigquery/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConfig** | Configuration for a BigQuery task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector** | |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask** | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/bigquery/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > BigQuery > Packages > flyteplugins.bigquery** | BigQuery connector plugin for Flyte. |
## Subpages
- **Integrations > BigQuery > Packages > flyteplugins.bigquery**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/bigquery/packages/flyteplugins.bigquery ===
# flyteplugins.bigquery
BigQuery connector plugin for Flyte.
This plugin provides integration between Flyte tasks and Google BigQuery,
enabling you to run parameterized SQL queries as Flyte tasks with full
observability, retries, and caching.
Key features:
- Parameterized SQL queries with typed inputs
- Returns query results as DataFrames
- Automatic links to the BigQuery job console in the Flyte UI
- Query cancellation on task abort
Basic usage example:
```python
import flyte
from flyte.io import DataFrame
from flyteplugins.bigquery import BigQueryConfig, BigQueryTask

config = BigQueryConfig(
    ProjectID="my-gcp-project",
    Location="US",
)

query_task = BigQueryTask(
    name="count_events",
    query_template="SELECT COUNT(*) AS total FROM `{ds}.events` WHERE date = @date",
    plugin_config=config,
    inputs={"date": str},
    output_dataframe_type=DataFrame[dict],
)

@flyte.task
def run_query(date: str) -> DataFrame[dict]:
    return query_task(date=date)
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConfig** | Configuration for a BigQuery task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector** | |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask** | |
## Subpages
- **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConfig**
- **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector**
- **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/bigquery/packages/flyteplugins.bigquery/bigqueryconfig ===
# BigQueryConfig
**Package:** `flyteplugins.bigquery`
Configuration for a BigQuery task.
Attributes:
ProjectID: The Google Cloud project ID that owns the BigQuery dataset.
Location: The geographic location of the dataset, e.g. `"US"` or `"EU"`.
Defaults to the project's default location if not specified.
QueryJobConfig: Optional advanced job configuration passed directly to the
BigQuery client. Use this to set query parameters, destination tables,
time partitioning, etc.
## Parameters
```python
class BigQueryConfig(
ProjectID: str,
Location: typing.Optional[str],
QueryJobConfig: typing.Optional[google.cloud.bigquery.job.query.QueryJobConfig],
)
```
| Parameter | Type | Description |
|-|-|-|
| `ProjectID` | `str` | |
| `Location` | `typing.Optional[str]` | |
| `QueryJobConfig` | `typing.Optional[google.cloud.bigquery.job.query.QueryJobConfig]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/bigquery/packages/flyteplugins.bigquery/bigqueryconnector ===
# BigQueryConnector
**Package:** `flyteplugins.bigquery`
## Methods
| Method | Description |
|-|-|
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector > Methods > create()** | Return a resource meta that can be used to get the status of the task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector > Methods > delete()** | Delete the task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector > Methods > get()** | Return the status of the task, and return the outputs in some cases. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector > Methods > get_logs()** | Return the logs for the task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryConnector > Methods > get_metrics()** | Return the metrics for the task. |
### create()
```python
def create(
task_template: flyteidl2.core.tasks_pb2.TaskTemplate,
inputs: typing.Optional[typing.Dict[str, typing.Any]],
google_application_credentials: typing.Optional[str],
kwargs,
) -> flyteplugins.bigquery.connector.BigQueryMetadata
```
Return a resource meta that can be used to get the status of the task.
| Parameter | Type | Description |
|-|-|-|
| `task_template` | `flyteidl2.core.tasks_pb2.TaskTemplate` | |
| `inputs` | `typing.Optional[typing.Dict[str, typing.Any]]` | |
| `google_application_credentials` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
### delete()
```python
def delete(
resource_meta: flyteplugins.bigquery.connector.BigQueryMetadata,
google_application_credentials: typing.Optional[str],
kwargs,
)
```
Delete the task. This call should be idempotent. It should raise an error if it fails to delete the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyteplugins.bigquery.connector.BigQueryMetadata` | |
| `google_application_credentials` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
### get()
```python
def get(
resource_meta: flyteplugins.bigquery.connector.BigQueryMetadata,
google_application_credentials: typing.Optional[str],
kwargs,
) -> flyte.connectors._connector.Resource
```
Return the status of the task, and in some cases return the outputs. For example, a BigQuery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
and the propeller writes the structured dataset to the blob store.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyteplugins.bigquery.connector.BigQueryMetadata` | |
| `google_application_credentials` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
### get_logs()
```python
def get_logs(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskLogsResponse
```
Return the logs for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
### get_metrics()
```python
def get_metrics(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskMetricsResponse
```
Return the metrics for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
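The `create()`/`get()`/`delete()` methods are typically driven as a poll loop by the runtime. A hedged sketch with a fake connector (the `FakeMeta` and `FakeConnector` names, and the two-poll completion, are invented for illustration; a real connector calls the BigQuery API):

```python
import time
from dataclasses import dataclass


@dataclass
class FakeMeta:
    """Stand-in for BigQueryMetadata: just enough to identify the job."""
    job_id: str


class FakeConnector:
    """Toy connector illustrating the create()/get()/delete() lifecycle.
    This one reports SUCCEEDED after two polls so the driver loop below
    has something to wait for."""

    def __init__(self) -> None:
        self._polls: dict = {}

    def create(self, query: str) -> FakeMeta:
        meta = FakeMeta(job_id=f"job-{len(self._polls)}")
        self._polls[meta.job_id] = 0
        return meta

    def get(self, meta: FakeMeta) -> str:
        self._polls[meta.job_id] += 1
        return "SUCCEEDED" if self._polls[meta.job_id] >= 2 else "RUNNING"

    def delete(self, meta: FakeMeta) -> None:
        # Idempotent, as the real delete() is documented to be.
        self._polls.pop(meta.job_id, None)


def run_to_completion(connector: FakeConnector, query: str) -> str:
    """How a runtime typically drives a connector: create, poll get(), return."""
    meta = connector.create(query)
    phase = connector.get(meta)
    while phase == "RUNNING":
        time.sleep(0)  # a real driver would back off between polls
        phase = connector.get(meta)
    return phase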
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/bigquery/packages/flyteplugins.bigquery/bigquerytask ===
# BigQueryTask
**Package:** `flyteplugins.bigquery`
## Parameters
```python
class BigQueryTask(
name: str,
query_template: str,
plugin_config: flyteplugins.bigquery.task.BigQueryConfig,
inputs: typing.Optional[typing.Dict[str, typing.Type]],
output_dataframe_type: typing.Optional[typing.Type[flyte.io._dataframe.dataframe.DataFrame]],
google_application_credentials: typing.Optional[str],
kwargs,
)
```
To be used to query BigQuery Tables.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The Name of this task, should be unique in the project |
| `query_template` | `str` | The actual query to run. We use Flyte's Golang templating format for Query templating. Refer to the templating documentation |
| `plugin_config` | `flyteplugins.bigquery.task.BigQueryConfig` | BigQueryConfig object |
| `inputs` | `typing.Optional[typing.Dict[str, typing.Type]]` | Name and type of inputs specified as an ordered dictionary |
| `output_dataframe_type` | `typing.Optional[typing.Type[flyte.io._dataframe.dataframe.DataFrame]]` | If some data is produced by this query, then you can specify the output dataframe type. |
| `google_application_credentials` | `typing.Optional[str]` | The name of the secret containing the Google Application Credentials. |
| `kwargs` | `**kwargs` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `native_interface` | `None` | |
| `source_file` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > aio()** | The aio function allows executing "sync" tasks, in an async context. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > config()** | Returns additional configuration for the task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > container_args()** | Returns the container args for the task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > custom_config()** | Returns additional configuration for the task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > data_loading_config()** | This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > execute()** | |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > forward()** | Think of this as a local execute method for your task. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > override()** | Override various parameters of the task template. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > post()** | This is the postexecute function that will be. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > pre()** | This is the preexecute function that will be. |
| **Integrations > BigQuery > Packages > flyteplugins.bigquery > BigQueryTask > Methods > sql()** | Returns the SQL for the task. |
### aio()
```python
def aio(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
The aio function allows executing "sync" tasks in an async context. This helps with migrating v1-defined sync
tasks to be used within an asyncio parent task.
This function will also re-raise exceptions from the underlying task.
Example:
```python
import asyncio
from typing import List

@env.task
def my_legacy_task(x: int) -> int:
    return x

@env.task
async def my_new_parent_task(n: int) -> List[int]:
    collect = []
    for x in range(n):
        collect.append(my_legacy_task.aio(x))
    return await asyncio.gather(*collect)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### config()
```python
def config(
sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### container_args()
```python
def container_args(
sctx: SerializationContext,
) -> List[str]
```
Returns the container args for the task: the list of arguments used to configure the task execution
environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### custom_config()
```python
def custom_config(
sctx: flyte.models.SerializationContext,
) -> typing.Optional[typing.Dict[str, typing.Any]]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `flyte.models.SerializationContext` | |
### data_loading_config()
```python
def data_loading_config(
sctx: SerializationContext,
) -> DataLoadingConfig
```
This configuration allows executing raw containers in Flyte using the Flyte CoPilot system.
Flyte CoPilot eliminates the need for the SDK inside the container. Any inputs required by the user's container
are side-loaded into the `input_path`, and any outputs generated by the user's container within `output_path`
are automatically uploaded.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### execute()
```python
def execute(
kwargs,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `kwargs` | `**kwargs` | |
### forward()
```python
def forward(
args: *args,
kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
Think of this as a local execute method for your task. This function is invoked by the `__call__` method
when not in a Flyte task execution context.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### override()
```python
def override(
short_name: Optional[str],
resources: Optional[Resources],
cache: Optional[CacheRequest],
retries: Union[int, RetryStrategy],
timeout: Optional[TimeoutType],
reusable: Union[ReusePolicy, Literal['off'], None],
env_vars: Optional[Dict[str, str]],
secrets: Optional[SecretRequest],
max_inline_io_bytes: int | None,
pod_template: Optional[Union[str, PodTemplate]],
queue: Optional[str],
interruptible: Optional[bool],
links: Tuple[Link, ...],
kwargs: **kwargs,
) -> TaskTemplate
```
Override various parameters of the task template. This allows for dynamic configuration of the task
when it is called, such as changing the image, resources, cache policy, etc.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional override for the short name of the task. |
| `resources` | `Optional[Resources]` | Optional override for the resources to use for the task. |
| `cache` | `Optional[CacheRequest]` | Optional override for the cache policy for the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional override for the number of retries for the task. |
| `timeout` | `Optional[TimeoutType]` | Optional override for the timeout for the task. |
| `reusable` | `Union[ReusePolicy, Literal['off'], None]` | Optional override for the reusability policy for the task. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional override for the environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional override for the secrets that will be injected into the task at runtime. |
| `max_inline_io_bytes` | `int \| None` | Optional override for the maximum allowed size (in bytes) for all inputs and outputs passed directly to the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional override for the pod template to use for the task. |
| `queue` | `Optional[str]` | Optional override for the queue to use for the task. |
| `interruptible` | `Optional[bool]` | Optional override for the interruptible policy for the task. |
| `links` | `Tuple[Link, ...]` | Optional override for the Links associated with the task. |
| `kwargs` | `**kwargs` | Additional keyword arguments for further overrides. Some fields like name, image, docs, and interface cannot be overridden. |
**Returns:** A new TaskTemplate instance with the overridden parameters.
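As an illustrative sketch, overriding a task at call time might look like the following (assuming a task `my_task` defined via `@env.task`; the field values shown are examples, not defaults):

```python
import flyte

# Build a new TaskTemplate with heavier resources and a retry policy;
# the original task definition is left unchanged.
tuned = my_task.override(
    resources=flyte.Resources(cpu=4, memory="8Gi"),
    retries=3,
    env_vars={"LOG_LEVEL": "debug"},
)

# Invoke the overridden template exactly like the original task.
result = tuned(x=10)
```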
### post()
```python
def post(
return_vals: Any,
) -> Any
```
The post-execute hook, called after the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `return_vals` | `Any` | |
### pre()
```python
def pre(
args,
kwargs,
) -> Dict[str, Any]
```
The pre-execute hook, called before the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### sql()
```python
def sql(
sctx: flyte.models.SerializationContext,
) -> typing.Optional[str]
```
Returns the SQL statement for the task, if any. This is usually used by plugins
that execute SQL against an external engine.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `flyte.models.SerializationContext` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen ===
# Code generation
## Subpages
- **Integrations > Code generation > Classes**
- **Integrations > Code generation > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Code generation > Packages > flyteplugins.codegen > AutoCoderAgent** |Agent for single-file Python code generation with automatic testing and iteration. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodeGenEvalResult** |Result from code generation and evaluation. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodePlan** |Structured plan for the code solution. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodeSolution** |Structured code solution. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > ErrorDiagnosis** |Structured diagnosis of execution errors. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > ImageConfig** |Configuration for Docker image building at runtime. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Code generation > Packages > flyteplugins.codegen** | |
## Subpages
- **Integrations > Code generation > Packages > flyteplugins.codegen**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages/flyteplugins.codegen ===
# flyteplugins.codegen
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Code generation > Packages > flyteplugins.codegen > AutoCoderAgent** | Agent for single-file Python code generation with automatic testing and iteration. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodeGenEvalResult** | Result from code generation and evaluation. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodePlan** | Structured plan for the code solution. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodeSolution** | Structured code solution. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > ErrorDiagnosis** | Structured diagnosis of execution errors. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > ImageConfig** | Configuration for Docker image building at runtime. |
## Subpages
- **Integrations > Code generation > Packages > flyteplugins.codegen > AutoCoderAgent**
- **Integrations > Code generation > Packages > flyteplugins.codegen > CodeGenEvalResult**
- **Integrations > Code generation > Packages > flyteplugins.codegen > CodePlan**
- **Integrations > Code generation > Packages > flyteplugins.codegen > CodeSolution**
- **Integrations > Code generation > Packages > flyteplugins.codegen > ErrorDiagnosis**
- **Integrations > Code generation > Packages > flyteplugins.codegen > ImageConfig**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages/flyteplugins.codegen/autocoderagent ===
# AutoCoderAgent
**Package:** `flyteplugins.codegen`
Agent for single-file Python code generation with automatic testing and iteration.
Generates a single Python script, builds a sandbox image with the required
dependencies, runs pytest-based tests, and iterates until tests pass.
Uses Sandbox internally for isolated code execution.
Args:
- `name`: Name for the agent (used in image naming and logging).
- `model`: LLM model to use (required). Must support structured outputs. For `backend="litellm"` (the default): e.g. `"gpt-4.1"`, `"claude-sonnet-4-20250514"`. For `backend="claude"`: a Claude model (`"sonnet"`, `"opus"`, `"haiku"`).
- `system_prompt`: Optional system prompt to use for the LLM. If not provided, a default prompt with structured output requirements is used.
- `api_key`: Optional environment variable name for the LLM API key.
- `api_base`: Optional base URL for the LLM API.
- `litellm_params`: Optional dict of additional parameters to pass to LiteLLM calls.
- `base_packages`: Optional list of base packages to install in the sandbox.
- `resources`: Optional resources for sandbox execution (default: cpu=1, 1Gi).
- `image_config`: Optional image configuration for sandbox execution.
- `max_iterations`: Maximum number of generate-test-fix iterations. Defaults to 10.
- `max_sample_rows`: Optional maximum number of rows to use for sample data. Defaults to 100.
- `skip_tests`: Optional flag to skip testing. Defaults to False.
- `sandbox_retries`: Number of Flyte task-level retries for each sandbox execution. Defaults to 0.
- `timeout`: Timeout in seconds for sandboxes. Defaults to None.
- `env_vars`: Environment variables to pass to sandboxes.
- `secrets`: `flyte.Secret` objects to make available to sandboxes.
- `cache`: CacheRequest for sandboxes: `"auto"`, `"override"`, or `"disable"`. Defaults to `"auto"`.
- `backend`: Execution backend: `"litellm"` (default) or `"claude"`.
- `agent_max_turns`: Maximum agent turns when `backend="claude"`. Defaults to 50.

Example:

```python
import flyte
from flyte.io import File
from flyte.sandbox import sandbox_environment
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(
    model="gpt-4.1",
    base_packages=["pandas"],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)

env = flyte.TaskEnvironment(
    name="my-env",
    depends_on=[sandbox_environment],
)

@env.task
async def my_task(data_file: File) -> float:
    result = await agent.generate.aio(
        prompt="Process CSV data",
        samples={"csv": data_file},
        outputs={"total": float},
    )
    return await result.run.aio()
```
## Parameters
```python
class AutoCoderAgent(
model: str,
name: str,
system_prompt: typing.Optional[str],
api_key: typing.Optional[str],
api_base: typing.Optional[str],
litellm_params: typing.Optional[dict],
base_packages: typing.Optional[list[str]],
resources: typing.Optional[flyte._resources.Resources],
image_config: typing.Optional[flyte.sandbox._code_sandbox.ImageConfig],
max_iterations: int,
max_sample_rows: int,
skip_tests: bool,
sandbox_retries: int,
timeout: typing.Optional[int],
env_vars: typing.Optional[dict[str, str]],
secrets: typing.Optional[list],
cache: str,
backend: typing.Literal['litellm', 'claude'],
agent_max_turns: int,
)
```
| Parameter | Type | Description |
|-|-|-|
| `model` | `str` | |
| `name` | `str` | |
| `system_prompt` | `typing.Optional[str]` | |
| `api_key` | `typing.Optional[str]` | |
| `api_base` | `typing.Optional[str]` | |
| `litellm_params` | `typing.Optional[dict]` | |
| `base_packages` | `typing.Optional[list[str]]` | |
| `resources` | `typing.Optional[flyte._resources.Resources]` | |
| `image_config` | `typing.Optional[flyte.sandbox._code_sandbox.ImageConfig]` | |
| `max_iterations` | `int` | |
| `max_sample_rows` | `int` | |
| `skip_tests` | `bool` | |
| `sandbox_retries` | `int` | |
| `timeout` | `typing.Optional[int]` | |
| `env_vars` | `typing.Optional[dict[str, str]]` | |
| `secrets` | `typing.Optional[list]` | |
| `cache` | `str` | |
| `backend` | `typing.Literal['litellm', 'claude']` | |
| `agent_max_turns` | `int` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Code generation > Packages > flyteplugins.codegen > AutoCoderAgent > Methods > generate()** | Generate and evaluate code in an isolated sandbox. |
### generate()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await agent.generate.aio(...)`.
```python
def generate(
prompt: str,
schema: typing.Optional[str],
constraints: typing.Optional[list[str]],
samples: typing.Optional[dict[str, pandas.core.frame.DataFrame | flyte.io._file.File]],
inputs: typing.Optional[dict[str, type]],
outputs: typing.Optional[dict[str, type]],
) -> flyteplugins.codegen.core.types.CodeGenEvalResult
```
Generate and evaluate code in an isolated sandbox.
Each call is independent with its own sandbox, packages and execution environment.
| Parameter | Type | Description |
|-|-|-|
| `prompt` | `str` | The prompt to generate code from. |
| `schema` | `typing.Optional[str]` | Optional free-form context about data formats, structures or schemas. Included verbatim in the LLM prompt. Use for input formats, output schemas, database schemas or any structural context the LLM needs to generate code. |
| `constraints` | `typing.Optional[list[str]]` | Optional list of constraints or requirements. |
| `samples` | `typing.Optional[dict[str, pandas.core.frame.DataFrame \| flyte.io._file.File]]` | Optional dict of sample data. Each value is sampled and included in the LLM prompt for context, and converted to a File input for the sandbox. Values are used as defaults at runtime β override them when calling `result.run()` or `result.as_task()`. Supported types: File, pd.DataFrame. |
| `inputs` | `typing.Optional[dict[str, type]]` | Optional dict declaring non-sample CLI argument types (e.g., `{"threshold": float, "mode": str}`). Sample entries are automatically added as File inputs β don't redeclare them here. Supported types: str, int, float, bool, File. |
| `outputs` | `typing.Optional[dict[str, type]]` | Optional dict defining output types (e.g., `{"result": str, "report": File}`). Supported types: str, int, float, bool, datetime, timedelta, File. |
**Returns:** CodeGenEvalResult with solution and execution details.
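As a further sketch, a synchronous call declaring an extra input and typed outputs might look like this (the prompt, input names, and constraint strings are illustrative; `agent` and `data_file` are as in the class-level example above):

```python
result = agent.generate(
    prompt="Filter rows where value exceeds a threshold and report the row count",
    samples={"csv": data_file},       # sampled for the LLM and added as a File input
    inputs={"threshold": float},      # extra CLI argument; samples are not redeclared
    outputs={"row_count": int},       # the generated script writes /var/outputs/row_count
    constraints=["use only pandas"],
)
if result.success:
    print(result.solution.code)       # the generated Python script
    print(result.detected_packages)   # packages the LLM inferred from imports
```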
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages/flyteplugins.codegen/codegenevalresult ===
# CodeGenEvalResult
**Package:** `flyteplugins.codegen`
Result from code generation and evaluation.
## Parameters
```python
class CodeGenEvalResult(
plan: typing.Optional[flyteplugins.codegen.core.types.CodePlan],
solution: flyteplugins.codegen.core.types.CodeSolution,
tests: typing.Optional[str],
success: bool,
output: str,
exit_code: int,
error: typing.Optional[str],
attempts: int,
conversation_history: list[dict[str, str]],
detected_packages: list[str],
detected_system_packages: list[str],
image: typing.Optional[str],
total_input_tokens: int,
total_output_tokens: int,
declared_inputs: typing.Optional[dict[str, type]],
declared_outputs: typing.Optional[dict[str, type]],
data_context: typing.Optional[str],
original_samples: typing.Optional[dict[str, flyte.io._file.File]],
generated_schemas: typing.Optional[dict[str, str]],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `plan` | `typing.Optional[flyteplugins.codegen.core.types.CodePlan]` | |
| `solution` | `flyteplugins.codegen.core.types.CodeSolution` | |
| `tests` | `typing.Optional[str]` | |
| `success` | `bool` | |
| `output` | `str` | |
| `exit_code` | `int` | |
| `error` | `typing.Optional[str]` | |
| `attempts` | `int` | |
| `conversation_history` | `list[dict[str, str]]` | |
| `detected_packages` | `list[str]` | Language packages detected by LLM from imports |
| `detected_system_packages` | `list[str]` | System packages detected by LLM |
| `image` | `typing.Optional[str]` | The Flyte Image built with all dependencies |
| `total_input_tokens` | `int` | Total input tokens used across all LLM calls |
| `total_output_tokens` | `int` | Total output tokens used across all LLM calls |
| `declared_inputs` | `typing.Optional[dict[str, type]]` | Input types (user-provided or inferred from samples) |
| `declared_outputs` | `typing.Optional[dict[str, type]]` | Output types declared by user |
| `data_context` | `typing.Optional[str]` | Extracted data context (schema, stats, patterns, samples) used for code generation |
| `original_samples` | `typing.Optional[dict[str, flyte.io._file.File]]` | Sample data converted to Files (defaults for run()/as_task()) |
| `generated_schemas` | `typing.Optional[dict[str, str]]` | Auto-generated Pandera schemas (as Python code strings) for validating data inputs |
## Methods
| Method | Description |
|-|-|
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodeGenEvalResult > Methods > as_task()** | Create a reusable task that runs the generated code in an isolated sandbox. |
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodeGenEvalResult > Methods > run()** | Run generated code in an isolated sandbox (one-off execution). |
### as_task()
```python
def as_task(
name: str,
resources: typing.Optional[flyte._resources.Resources],
retries: int,
timeout: typing.Optional[int],
env_vars: typing.Optional[dict[str, str]],
secrets: typing.Optional[list],
cache: str,
)
```
Create a reusable task that runs the generated code in an isolated sandbox.
The generated code will write outputs to /var/outputs/{output_name} files.
Returns a callable wrapper that automatically provides the script file.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name for the sandbox |
| `resources` | `typing.Optional[flyte._resources.Resources]` | Optional resources for the task |
| `retries` | `int` | Number of retries for the task. Defaults to 0. |
| `timeout` | `typing.Optional[int]` | Timeout in seconds. Defaults to None. |
| `env_vars` | `typing.Optional[dict[str, str]]` | Environment variables to pass to the sandbox. |
| `secrets` | `typing.Optional[list]` | flyte.Secret objects to make available. |
| `cache` | `str` | CacheRequest: "auto", "override", or "disable". Defaults to "auto". |
**Returns:** Callable task wrapper with the default inputs baked in. Call with your other declared inputs.
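A sketch of turning a successful result into a reusable task (the name and input are illustrative, assuming `result` came from a `generate()` call that declared a `threshold` input):

```python
# Wrap the generated script as a task; sample inputs captured
# at generate() time are baked in as defaults.
process = result.as_task(
    name="process-csv",
    retries=2,
    cache="auto",
)

# Call it with the remaining declared inputs.
row_count = process(threshold=0.5)
```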
### run()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `outputs = await result.run.aio(...)`.
```python
def run(
name: str,
resources: typing.Optional[flyte._resources.Resources],
retries: int,
timeout: typing.Optional[int],
env_vars: typing.Optional[dict[str, str]],
secrets: typing.Optional[list],
cache: str,
overrides,
) -> typing.Any
```
Run generated code in an isolated sandbox (one-off execution).
If samples were provided during generate(), they are used as defaults.
Override any input by passing it as a keyword argument. If no samples
exist, all declared inputs must be provided via `**overrides`.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Name for the sandbox |
| `resources` | `typing.Optional[flyte._resources.Resources]` | Optional resources for the task |
| `retries` | `int` | Number of retries for the task. Defaults to 0. |
| `timeout` | `typing.Optional[int]` | Timeout in seconds. Defaults to None. |
| `env_vars` | `typing.Optional[dict[str, str]]` | Environment variables to pass to the sandbox. |
| `secrets` | `typing.Optional[list]` | flyte.Secret objects to make available. |
| `cache` | `str` | CacheRequest: "auto", "override", or "disable". Defaults to "auto". |
| `overrides` | | |
**Returns:** Tuple of typed outputs.
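A minimal sketch of a one-off execution inside an async task, overriding one declared input (the input name is illustrative, following the `generate()` example above):

```python
# Samples from generate() are used as defaults; any declared
# input can be overridden by keyword. Returns typed outputs.
outputs = await result.run.aio(threshold=0.9)
```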
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages/flyteplugins.codegen/codeplan ===
# CodePlan
**Package:** `flyteplugins.codegen`
Structured plan for the code solution.
## Parameters
```python
class CodePlan(
description: str,
approach: str,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `description` | `str` | Overall description of the solution |
| `approach` | `str` | High-level approach and algorithm to solve the problem |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages/flyteplugins.codegen/codesolution ===
# CodeSolution
**Package:** `flyteplugins.codegen`
Structured code solution.
## Parameters
```python
class CodeSolution(
language: str,
code: str,
system_packages: list[str],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `language` | `str` | Programming language |
| `code` | `str` | Complete executable code including imports and dependencies |
| `system_packages` | `list[str]` | System packages needed (e.g., gcc, build-essential, curl) |
## Methods
| Method | Description |
|-|-|
| **Integrations > Code generation > Packages > flyteplugins.codegen > CodeSolution > Methods > normalize_language()** | |
### normalize_language()
```python
def normalize_language(
v: str,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `v` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages/flyteplugins.codegen/errordiagnosis ===
# ErrorDiagnosis
**Package:** `flyteplugins.codegen`
Structured diagnosis of execution errors.
## Parameters
```python
class ErrorDiagnosis(
failures: list[flyteplugins.codegen.core.types.TestFailure],
needs_system_packages: list[str],
needs_language_packages: list[str],
needs_additional_commands: list[str],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `failures` | `list[flyteplugins.codegen.core.types.TestFailure]` | Individual test failures with their diagnoses |
| `needs_system_packages` | `list[str]` | System packages needed (e.g., gcc, pkg-config). |
| `needs_language_packages` | `list[str]` | Language packages needed. |
| `needs_additional_commands` | `list[str]` | Additional RUN commands (e.g., apt-get update, mkdir /data, wget files). |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/codegen/packages/flyteplugins.codegen/imageconfig ===
# ImageConfig
**Package:** `flyteplugins.codegen`
Configuration for Docker image building at runtime.
## Parameters
```python
class ImageConfig(
registry: typing.Optional[str],
registry_secret: typing.Optional[str],
python_version: typing.Optional[tuple[int, int]],
)
```
| Parameter | Type | Description |
|-|-|-|
| `registry` | `typing.Optional[str]` | |
| `registry_secret` | `typing.Optional[str]` | |
| `python_version` | `typing.Optional[tuple[int, int]]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/dask ===
# Dask
## Subpages
- **Integrations > Dask > Classes**
- **Integrations > Dask > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/dask/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Dask > Packages > flyteplugins.dask > Dask** |Configuration for the dask task. |
| **Integrations > Dask > Packages > flyteplugins.dask > Scheduler** |Configuration for the scheduler pod. |
| **Integrations > Dask > Packages > flyteplugins.dask > WorkerGroup** |Configuration for a group of dask worker pods. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/dask/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Dask > Packages > flyteplugins.dask** | |
## Subpages
- **Integrations > Dask > Packages > flyteplugins.dask**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/dask/packages/flyteplugins.dask ===
# flyteplugins.dask
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Dask > Packages > flyteplugins.dask > Dask** | Configuration for the dask task. |
| **Integrations > Dask > Packages > flyteplugins.dask > Scheduler** | Configuration for the scheduler pod. |
| **Integrations > Dask > Packages > flyteplugins.dask > WorkerGroup** | Configuration for a group of dask worker pods. |
## Subpages
- **Integrations > Dask > Packages > flyteplugins.dask > Dask**
- **Integrations > Dask > Packages > flyteplugins.dask > Scheduler**
- **Integrations > Dask > Packages > flyteplugins.dask > WorkerGroup**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/dask/packages/flyteplugins.dask/dask ===
# Dask
**Package:** `flyteplugins.dask`
Configuration for the dask task.
## Parameters
```python
class Dask(
scheduler: flyteplugins.dask.task.Scheduler,
workers: flyteplugins.dask.task.WorkerGroup,
)
```
| Parameter | Type | Description |
|-|-|-|
| `scheduler` | `flyteplugins.dask.task.Scheduler` | Configuration for the scheduler pod. Optional, defaults to `Scheduler()`. |
| `workers` | `flyteplugins.dask.task.WorkerGroup` | Configuration for the pods of the default worker group. Optional, defaults to `WorkerGroup()`. |
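As a sketch of wiring this configuration into a task environment (under the assumption that `Dask` is passed as a `plugin_config`, mirroring the pattern shown for other plugins such as Databricks):

```python
import flyte
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

# One scheduler pod and a default worker group of four pods.
dask_config = Dask(
    scheduler=Scheduler(resources=flyte.Resources(cpu=1, memory="2Gi")),
    workers=WorkerGroup(
        number_of_workers=4,
        resources=flyte.Resources(cpu=2, memory="4Gi"),
    ),
)

env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
)
```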
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/dask/packages/flyteplugins.dask/scheduler ===
# Scheduler
**Package:** `flyteplugins.dask`
Configuration for the scheduler pod.
## Parameters
```python
class Scheduler(
image: typing.Optional[str],
resources: typing.Optional[flyte._resources.Resources],
)
```
| Parameter | Type | Description |
|-|-|-|
| `image` | `typing.Optional[str]` | Custom image to use. If `None`, will use the same image the task was registered with. Optional, defaults to None. The image must have `dask[distributed]` installed and should have the same Python environment as the rest of the cluster (job runner pod + worker pods). |
| `resources` | `typing.Optional[flyte._resources.Resources]` | Resources to request for the scheduler pod. Optional, defaults to None. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/dask/packages/flyteplugins.dask/workergroup ===
# WorkerGroup
**Package:** `flyteplugins.dask`
Configuration for a group of dask worker pods.
## Parameters
```python
class WorkerGroup(
number_of_workers: typing.Optional[int],
image: typing.Optional[str],
resources: typing.Optional[flyte._resources.Resources],
)
```
| Parameter | Type | Description |
|-|-|-|
| `number_of_workers` | `typing.Optional[int]` | Number of workers to use. Optional, defaults to 1. |
| `image` | `typing.Optional[str]` | Custom image to use. If `None`, will use the same image the task was registered with. Optional, defaults to None. The image must have `dask[distributed]` installed. The provided image should have the same Python environment as the job runner/driver as well as the scheduler. |
| `resources` | `typing.Optional[flyte._resources.Resources]` | Resources to request for the worker pods. Optional, defaults to None. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/databricks ===
# Databricks
## Subpages
- **Integrations > Databricks > Classes**
- **Integrations > Databricks > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/databricks/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Databricks > Packages > flyteplugins.databricks > Databricks** |Configuration for a Databricks task. |
| **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector** | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/databricks/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Databricks > Packages > flyteplugins.databricks** | Databricks connector plugin for Flyte. |
## Subpages
- **Integrations > Databricks > Packages > flyteplugins.databricks**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/databricks/packages/flyteplugins.databricks ===
# flyteplugins.databricks
Databricks connector plugin for Flyte.
This plugin provides integration between Flyte tasks and Databricks,
enabling you to run PySpark jobs on Databricks clusters as Flyte tasks
with full observability, retries, and caching.
Key features:
- Run PySpark tasks natively on Databricks clusters
- Configurable cluster spec via the Databricks Jobs API
- Automatic job lifecycle management: create, poll, cancel
- Automatic links to the Databricks job run UI in the Flyte UI
Basic usage example:
```python
import flyte
from flyteplugins.databricks import Databricks
databricks_config = Databricks(
    spark_conf={"spark.executor.memory": "4g"},
    databricks_conf={
        "run_name": "my_job",
        "new_cluster": {
            "spark_version": "13.3.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "num_workers": 2,
        },
    },
    databricks_instance="myorg.cloud.databricks.com",
    databricks_token="databricks_token_secret",
)

env = flyte.TaskEnvironment(
    name="databricks_env",
    plugin_config=databricks_config,
    image=flyte.Image.from_debian_base(name="pyspark").with_pip_packages(
        "flyteplugins-databricks"
    ),
)

@env.task
def process_data(input_path: str) -> int:
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.read.parquet(input_path)
    return df.count()
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Databricks > Packages > flyteplugins.databricks > Databricks** | Configuration for a Databricks task. |
| **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector** | |
## Subpages
- **Integrations > Databricks > Packages > flyteplugins.databricks > Databricks**
- **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/databricks/packages/flyteplugins.databricks/databricks ===
# Databricks
**Package:** `flyteplugins.databricks`
Configuration for a Databricks task.
Tasks configured with this will execute natively on Databricks as a
distributed PySpark job. Extends `Spark` with Databricks-specific
cluster and authentication settings.
Attributes:
- `spark_conf`: Spark configuration key-value pairs, e.g. `{"spark.executor.memory": "4g"}`.
- `hadoop_conf`: Hadoop configuration key-value pairs.
- `executor_path`: Path to the Python binary used for PySpark execution. Defaults to the interpreter path from the serialization context.
- `applications_path`: Path to the main application file. Defaults to the task entrypoint path.
- `driver_pod`: Pod template applied to the Spark driver pod.
- `executor_pod`: Pod template applied to the Spark executor pods.
- `databricks_conf`: Databricks job configuration dict compliant with the Databricks Jobs API v2.1 (also supports v2.0 use cases). Typically includes `new_cluster` or `existing_cluster_id`, `run_name`, and other job settings.
- `databricks_instance`: Domain name of your Databricks deployment, e.g. `"myorg.cloud.databricks.com"`.
- `databricks_token`: Name of the Flyte secret containing the Databricks API token used for authentication.
## Parameters
```python
class Databricks(
spark_conf: typing.Optional[typing.Dict[str, str]],
hadoop_conf: typing.Optional[typing.Dict[str, str]],
executor_path: typing.Optional[str],
applications_path: typing.Optional[str],
driver_pod: typing.Optional[flyte._pod.PodTemplate],
executor_pod: typing.Optional[flyte._pod.PodTemplate],
databricks_conf: typing.Optional[typing.Dict[str, typing.Union[str, dict]]],
databricks_instance: typing.Optional[str],
databricks_token: typing.Optional[str],
)
```
| Parameter | Type | Description |
|-|-|-|
| `spark_conf` | `typing.Optional[typing.Dict[str, str]]` | |
| `hadoop_conf` | `typing.Optional[typing.Dict[str, str]]` | |
| `executor_path` | `typing.Optional[str]` | |
| `applications_path` | `typing.Optional[str]` | |
| `driver_pod` | `typing.Optional[flyte._pod.PodTemplate]` | |
| `executor_pod` | `typing.Optional[flyte._pod.PodTemplate]` | |
| `databricks_conf` | `typing.Optional[typing.Dict[str, typing.Union[str, dict]]]` | |
| `databricks_instance` | `typing.Optional[str]` | |
| `databricks_token` | `typing.Optional[str]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/databricks/packages/flyteplugins.databricks/databricksconnector ===
# DatabricksConnector
**Package:** `flyteplugins.databricks`
## Methods
| Method | Description |
|-|-|
| **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector > Methods > create()** | Return a resource meta that can be used to get the status of the task. |
| **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector > Methods > delete()** | Delete the task. |
| **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector > Methods > get()** | Return the status of the task, and return the outputs in some cases. |
| **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector > Methods > get_logs()** | Return the logs for the task. |
| **Integrations > Databricks > Packages > flyteplugins.databricks > DatabricksConnector > Methods > get_metrics()** | Return the metrics for the task. |
### create()
```python
def create(
task_template: flyteidl2.core.tasks_pb2.TaskTemplate,
inputs: typing.Optional[typing.Dict[str, typing.Any]],
databricks_token: typing.Optional[str],
kwargs,
) -> flyteplugins.databricks.connector.DatabricksJobMetadata
```
Return a resource meta that can be used to get the status of the task.
| Parameter | Type | Description |
|-|-|-|
| `task_template` | `flyteidl2.core.tasks_pb2.TaskTemplate` | |
| `inputs` | `typing.Optional[typing.Dict[str, typing.Any]]` | |
| `databricks_token` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
### delete()
```python
def delete(
resource_meta: flyteplugins.databricks.connector.DatabricksJobMetadata,
databricks_token: typing.Optional[str],
kwargs,
)
```
Delete the task. This call should be idempotent and should raise an error if it fails to delete the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyteplugins.databricks.connector.DatabricksJobMetadata` | |
| `databricks_token` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
### get()
```python
def get(
resource_meta: flyteplugins.databricks.connector.DatabricksJobMetadata,
databricks_token: typing.Optional[str],
kwargs,
) -> flyte.connectors._connector.Resource
```
Return the status of the task and, in some cases, the outputs. For example, a BigQuery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
which then writes the structured dataset to the blob store.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyteplugins.databricks.connector.DatabricksJobMetadata` | |
| `databricks_token` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
### get_logs()
```python
def get_logs(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskLogsResponse
```
Return the logs for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
### get_metrics()
```python
def get_metrics(
resource_meta: flyte.connectors._connector.ResourceMeta,
kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskMetricsResponse
```
Return the metrics for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/gemini ===
# Gemini
## Subpages
- **Integrations > Gemini > Classes**
- **Integrations > Gemini > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/gemini/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Gemini > Packages > flyteplugins.gemini > Agent** | A Gemini agent configuration. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/gemini/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Gemini > Packages > flyteplugins.gemini** | Google Gemini plugin for Flyte. |
## Subpages
- **Integrations > Gemini > Packages > flyteplugins.gemini**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/gemini/packages/flyteplugins.gemini ===
# flyteplugins.gemini
Google Gemini plugin for Flyte.
This plugin provides integration between Flyte tasks and Google's Gemini API,
enabling you to use Flyte tasks as tools for Gemini agents. Tool calls run with
full Flyte observability, retries, and caching.
Key features:
- Use any Flyte task as a Gemini tool via `function_tool`
- Full agent loop with automatic tool dispatch via `run_agent`
- Configurable agent via `Agent` (model, system prompt, tools, iteration limits)
Basic usage example:
```python
import flyte
from flyteplugins.gemini import Agent, function_tool, run_agent
env = flyte.TaskEnvironment(
name="agent_env",
image=flyte.Image.from_debian_base(name="agent").with_pip_packages(
"flyteplugins-gemini"
),
)
@env.task
async def get_weather(city: str) -> str:
'''Get the current weather for a city.'''
return f"Weather in {city}: sunny, 22Β°C"
weather_tool = function_tool(get_weather)
@env.task
async def run_weather_agent(question: str) -> str:
return await run_agent(
prompt=question,
tools=[weather_tool],
model="gemini-2.5-flash",
)
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Gemini > Packages > flyteplugins.gemini > Agent** | A Gemini agent configuration. |
### Methods
| Method | Description |
|-|-|
| **Integrations > Gemini > Packages > flyteplugins.gemini > Methods > function_tool()** | Convert a function or Flyte task to a Gemini-compatible tool. |
| **Integrations > Gemini > Packages > flyteplugins.gemini > Methods > run_agent()** | Run a Gemini agent with the given tools and prompt. |
## Methods
#### function_tool()
```python
def function_tool(
func: typing.Union[flyte._task.AsyncFunctionTaskTemplate, typing.Callable, NoneType],
name: str | None,
description: str | None,
) -> FunctionTool | partial[FunctionTool]
```
Convert a function or Flyte task to a Gemini-compatible tool.
This function converts a Python function, a `@flyte.trace`-decorated function,
or a Flyte task into a `FunctionTool` that can be used with Gemini's function calling API.
The input schema is derived via the Flyte type engine, producing a JSON schema.
This ensures that `Literal` types, dataclasses, `FlyteFile`, and other Flyte-native
types are represented correctly.
For `@flyte.trace`-decorated functions, the tracing context is preserved
automatically, since `functools.wraps` maintains the original function's metadata.
Example:
```python
@env.task
async def get_weather(city: str) -> str:
'''Get the current weather for a city.'''
return f"Weather in {city}: sunny"
tool = function_tool(get_weather)
```
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Union[flyte._task.AsyncFunctionTaskTemplate, typing.Callable, NoneType]` | The function or Flyte task to convert. |
| `name` | `str \| None` | Optional custom name for the tool. Defaults to the function name. |
| `description` | `str \| None` | Optional custom description. Defaults to the function's docstring. |
**Returns**
A FunctionTool instance that can be used with run_agent().
#### run_agent()
```python
def run_agent(
prompt: str,
tools: list[flyteplugins.gemini.agents._function_tools.FunctionTool] | None,
agent: flyteplugins.gemini.agents._function_tools.Agent | None,
model: str,
system: str | None,
max_output_tokens: int,
max_iterations: int,
api_key: str | None,
) -> str
```
Run a Gemini agent with the given tools and prompt.
This function creates a Gemini conversation loop that can use tools
to accomplish tasks. It handles the back-and-forth of function calls
and responses until the agent produces a final text response.
Example:
```python
result = await run_agent(
prompt="What's the weather in SF?",
tools=[function_tool(get_weather)],
)
```
| Parameter | Type | Description |
|-|-|-|
| `prompt` | `str` | The user prompt to send to the agent. |
| `tools` | `list[flyteplugins.gemini.agents._function_tools.FunctionTool] \| None` | List of FunctionTool instances to make available to the agent. |
| `agent` | `flyteplugins.gemini.agents._function_tools.Agent \| None` | Optional Agent configuration. If provided, overrides other params. |
| `model` | `str` | The Gemini model to use. |
| `system` | `str \| None` | Optional system prompt. |
| `max_output_tokens` | `int` | Maximum tokens in the response. |
| `max_iterations` | `int` | Maximum number of tool call iterations. |
| `api_key` | `str \| None` | Google API key. Defaults to GOOGLE_API_KEY env var. |
**Returns**
The final text response from the agent.
## Subpages
- **Integrations > Gemini > Packages > flyteplugins.gemini > Agent**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/gemini/packages/flyteplugins.gemini/agent ===
# Agent
**Package:** `flyteplugins.gemini`
A Gemini agent configuration.
This class represents the configuration for a Gemini agent, including
the model to use, system instructions, and available tools.
Attributes:
- `name`: A human-readable name for this agent. Used for logging and identification only; not sent to the API.
- `instructions`: The system prompt passed to Gemini on every turn. Describes the agent's role, tone, and constraints.
- `model`: The Gemini model ID to use, e.g. `"gemini-2.5-flash"`.
- `tools`: List of `FunctionTool` instances the agent can invoke. Create tools with `function_tool()`.
- `max_output_tokens`: Maximum number of tokens in each Gemini response.
- `max_iterations`: Maximum number of function-call / response cycles before `run_agent` returns with a timeout message.
## Parameters
```python
class Agent(
name: str,
instructions: str,
model: str,
tools: list[flyteplugins.gemini.agents._function_tools.FunctionTool],
max_output_tokens: int,
max_iterations: int,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `instructions` | `str` | |
| `model` | `str` | |
| `tools` | `list[flyteplugins.gemini.agents._function_tools.FunctionTool]` | |
| `max_output_tokens` | `int` | |
| `max_iterations` | `int` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Gemini > Packages > flyteplugins.gemini > Agent > Methods > get_gemini_tools()** | Get tool definitions in Gemini format. |
### get_gemini_tools()
```python
def get_gemini_tools()
```
Get tool definitions in Gemini format.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/hitl ===
# Human-in-the-Loop
## Subpages
- **Integrations > Human-in-the-Loop > Classes**
- **Integrations > Human-in-the-Loop > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/hitl/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl > Event** | An event that waits for human input via an embedded FastAPI app. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/hitl/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl** | Human-in-the-Loop (HITL) plugin for Flyte. |
## Subpages
- **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/hitl/packages/flyteplugins.hitl ===
# flyteplugins.hitl
Human-in-the-Loop (HITL) plugin for Flyte.
This plugin provides an event-based API for pausing workflows and waiting for human input.
## Basic usage:
```python
import flyte
import flyteplugins.hitl as hitl
task_env = flyte.TaskEnvironment(
name="my-hitl-workflow",
image=flyte.Image.from_debian_base(python_version=(3, 12)),
resources=flyte.Resources(cpu=1, memory="512Mi"),
depends_on=[hitl.env],
)
@task_env.task(report=True)
async def main() -> int:
# Create an event (this serves the app if not already running)
event = await hitl.new_event.aio(
"integer_input_event",
data_type=int,
scope="run",
prompt="What should I add to x?",
)
y = await event.wait.aio()
return y
```
## Features:
- Event-based API for human-in-the-loop workflows
- Web form for human input
- Programmatic API for automated input
- Support for int, float, str, and bool data types
- Crash-resilient polling with object storage
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl > Event** | An event that waits for human input via an embedded FastAPI app. |
### Methods
| Method | Description |
|-|-|
| **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl > Methods > new_event()** | Create a new human-in-the-loop event. |
### Variables
| Property | Type | Description |
|-|-|-|
| `env` | `TaskEnvironment` | |
## Methods
#### new_event()
> [!NOTE]
> This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.
> `result = await new_event.aio()`.
```python
def new_event(
name: str,
data_type: Type[T],
scope: EventScope,
prompt: str,
timeout_seconds: int,
poll_interval_seconds: int,
) -> Event[T]
```
Create a new human-in-the-loop event.
This is a convenience function that wraps Event.create().
Example:
```python
# Async usage
event = await new_event.aio(
    "approval_event",
    data_type=bool,
    scope="run",
    prompt="Do you approve this action?",
)
approved = await event.wait.aio()

# Sync usage
event = new_event("value_event", data_type=int, scope="run", prompt="Enter a number")
value = event.wait()
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | A descriptive name for the event (used in logs and UI) |
| `data_type` | `Type[T]` | The expected type of the input (int, float, str, bool) |
| `scope` | `EventScope` | The scope of the event. Currently only "run" is supported. |
| `prompt` | `str` | The prompt to display to the human |
| `timeout_seconds` | `int` | Maximum time to wait for human input (default: 1 hour) |
| `poll_interval_seconds` | `int` | How often to check for a response (default: 5 seconds) |
**Returns**
An Event object that can be used to wait for the human input
## Subpages
- **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl > Event**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/hitl/packages/flyteplugins.hitl/event ===
# Event
**Package:** `flyteplugins.hitl`
An event that waits for human input via an embedded FastAPI app.
This class encapsulates the entire HITL functionality:
- Creates and serves a FastAPI app for receiving human input
- Provides endpoints for form-based and JSON-based submission
- Polls object storage for responses using durable sleep (crash-resilient)
The app is automatically served when the Event is created via `Event.create()`.
All infrastructure details (AppEnvironment, deployment) are abstracted away.
Example:
```python
# Create an event (serves the app) and wait for input
event = await Event.create.aio(
    "proceed_event",
    scope="run",
    prompt="What should I add to x?",
    data_type=int,
)
result = await event.wait.aio()

# Or synchronously
event = Event.create("my_event", scope="run", prompt="Enter value", data_type=str)
value = event.wait()
```
## Parameters
```python
class Event(
name: str,
scope: EventScope,
data_type: Type[T],
prompt: str,
request_id: str,
endpoint: str,
request_path: str,
response_path: str,
timeout_seconds: int,
poll_interval_seconds: int,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `scope` | `EventScope` | |
| `data_type` | `Type[T]` | |
| `prompt` | `str` | |
| `request_id` | `str` | |
| `endpoint` | `str` | |
| `request_path` | `str` | |
| `response_path` | `str` | |
| `timeout_seconds` | `int` | |
| `poll_interval_seconds` | `int` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `api_url` | `None` | API endpoint for programmatic submission. |
| `endpoint` | `None` | Base endpoint of the HITL app. |
| `form_url` | `None` | URL where humans can submit input for this event. |
## Methods
| Method | Description |
|-|-|
| **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl > Event > Methods > create()** | Create a new human-in-the-loop event and serve the app. |
| **Integrations > Human-in-the-Loop > Packages > flyteplugins.hitl > Event > Methods > wait()** | Wait for human input and return the result. |
### create()
> [!NOTE]
> This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.
> `result = await Event.create.aio()`.
```python
def create(
cls,
name: str,
data_type: Type[T],
scope: EventScope,
prompt: str,
timeout_seconds: int,
poll_interval_seconds: int,
) -> 'Event[T]'
```
Create a new human-in-the-loop event and serve the app.
This method creates an event that waits for human input via the FastAPI app.
The app is automatically served if not already running. All infrastructure
details are abstracted away - you just get an event to wait on.
Example:
```python
# Async usage
event = await Event.create.aio(
    "approval_event",
    scope="run",
    prompt="Do you approve this action?",
    data_type=bool,
)
approved = await event.wait.aio()

# Sync usage
event = Event.create("value_event", scope="run", prompt="Enter a number", data_type=int)
value = event.wait()
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | A descriptive name for the event (used in logs and UI) |
| `data_type` | `Type[T]` | The expected type of the input (int, float, str, bool) |
| `scope` | `EventScope` | The scope of the event. Currently only "run" is supported. |
| `prompt` | `str` | The prompt to display to the human |
| `timeout_seconds` | `int` | Maximum time to wait for human input (default: 1 hour) |
| `poll_interval_seconds` | `int` | How often to check for a response (default: 5 seconds) |
**Returns**
An Event object that can be used to wait for the human input
### wait()
> [!NOTE]
> This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, call `.aio()` on the method itself, e.g.
> `result = await event.wait.aio()`.
```python
def wait()
```
Wait for human input and return the result.
This method polls object storage for a response using durable sleep,
making it crash-resilient. If the task crashes and restarts, it will
resume polling from where it left off.
**Returns**
The value provided by the human, converted to the event's data_type
**Raises**
| Exception | Description |
|-|-|
| `TimeoutError` | If no response is received within the timeout |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/jsonl ===
# JSONL
## Subpages
- **Integrations > JSONL > Classes**
- **Integrations > JSONL > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/jsonl/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir** | A directory of sharded JSONL files. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlFile** | A file type for JSONL (JSON Lines) files, backed by `orjson` for fast serialization. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/jsonl/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > JSONL > Packages > flyteplugins.jsonl** | |
## Subpages
- **Integrations > JSONL > Packages > flyteplugins.jsonl**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/jsonl/packages/flyteplugins.jsonl ===
# flyteplugins.jsonl
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir** | A directory of sharded JSONL files. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlFile** | A file type for JSONL (JSON Lines) files, backed by `orjson` for fast serialization. |
## Subpages
- **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir**
- **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlFile**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/jsonl/packages/flyteplugins.jsonl/jsonldir ===
# JsonlDir
**Package:** `flyteplugins.jsonl`
A directory of sharded JSONL files.
Provides transparent iteration across shards on read and automatic shard
rotation on write. Inherits all `Dir` capabilities (remote storage,
walk, download, etc.).
Shard files are named `part-00000.jsonl` (or `.jsonl.zst` for
compressed shards), zero-padded to 5 digits and sorted alphabetically
on read. Mixed compression within a single directory is supported.
Example (Async read):
```python
@env.task
async def process(d: JsonlDir):
    async for record in d.iter_records():
        print(record)
```
Example (Async write):
```python
@env.task
async def create() -> JsonlDir:
    d = JsonlDir.new_remote("output_shards")
    async with d.writer(max_records_per_shard=1000) as w:
        for i in range(5000):
            await w.write({"id": i})
    return d
```
## Parameters
```python
class JsonlDir(
path: str,
name: typing.Optional[str],
format: str,
hash: typing.Optional[str],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `name` | `typing.Optional[str]` | |
| `format` | `str` | |
| `hash` | `typing.Optional[str]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `lazy_uploader` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > download()** | Asynchronously download the entire directory to a local path. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > download_sync()** | Synchronously download the entire directory to a local path. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > exists()** | Asynchronously check if the directory exists. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > exists_sync()** | Synchronously check if the directory exists. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > from_existing_remote()** | Create a Dir reference from an existing remote directory. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > from_local()** | Asynchronously create a new Dir by uploading a local directory to remote storage. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > from_local_sync()** | Synchronously create a new Dir by uploading a local directory to remote storage. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > get_file()** | Asynchronously get a specific file from the directory by name. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > get_file_sync()** | Synchronously get a specific file from the directory by name. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > iter_arrow_batches()** | Async generator that yields Arrow RecordBatches across all shards. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > iter_arrow_batches_sync()** | Sync generator that yields Arrow RecordBatches across all shards. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > iter_batches()** | Async generator that yields lists of records in batches. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > iter_batches_sync()** | Sync generator that yields lists of records in batches. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > iter_records()** | Async generator that yields records from all shards in sorted order. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > iter_records_sync()** | Sync generator that yields records from all shards in sorted order. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > list_files()** | Asynchronously get a list of all files in the directory (non-recursive). |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > list_files_sync()** | Synchronously get a list of all files in the directory (non-recursive). |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > model_post_init()** | This function is meant to behave like a BaseModel method to initialise private attributes. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > new_remote()** | Create a new Dir reference for a remote directory that will be written to. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > pre_init()** | Internal: Pydantic validator to set default name from path. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > schema_match()** | Internal: Check if incoming schema matches Dir schema. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > walk()** | Asynchronously walk through the directory and yield File objects. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > walk_sync()** | Synchronously walk through the directory and yield File objects. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > writer()** | Async context manager returning a `JsonlDirWriter`. |
| **Integrations > JSONL > Packages > flyteplugins.jsonl > JsonlDir > Methods > writer_sync()** | Sync context manager returning a `JsonlDirWriterSync`. |
### download()
```python
def download(
local_path: Optional[Union[str, Path]],
) -> str
```
Asynchronously download the entire directory to a local path.
Use this when you need to download all files in a directory to your local filesystem for processing.
Example (Async):
```python
@env.task
async def download_directory(d: Dir) -> str:
local_dir = await d.download()
# Process files in the local directory
return local_dir
```
Example (Async - Download to specific path):
```python
@env.task
async def download_to_path(d: Dir) -> str:
local_dir = await d.download("/tmp/my_data/")
return local_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the directory to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded directory
### download_sync()
```python
def download_sync(
local_path: Optional[Union[str, Path]],
) -> str
```
Synchronously download the entire directory to a local path.
Use this in non-async tasks when you need to download all files in a directory to your local filesystem.
Example (Sync):
```python
@env.task
def download_directory_sync(d: Dir) -> str:
local_dir = d.download_sync()
# Process files in the local directory
return local_dir
```
Example (Sync - Download to specific path):
```python
@env.task
def download_to_path_sync(d: Dir) -> str:
local_dir = d.download_sync("/tmp/my_data/")
return local_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the directory to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded directory
### exists()
```python
def exists()
```
Asynchronously check if the directory exists.
Example (Async):
```python
@env.task
async def check_directory(d: Dir) -> bool:
if await d.exists():
print("Directory exists!")
return True
return False
```
**Returns**
True if the directory exists, False otherwise
### exists_sync()
```python
def exists_sync()
```
Synchronously check if the directory exists.
Use this in non-async tasks or when you need synchronous directory existence checking.
Example (Sync):
```python
@env.task
def check_directory_sync(d: Dir) -> bool:
if d.exists_sync():
print("Directory exists!")
return True
return False
```
**Returns**
True if the directory exists, False otherwise
### from_existing_remote()
```python
def from_existing_remote(
remote_path: str,
dir_cache_key: Optional[str],
) -> Dir[T]
```
Create a Dir reference from an existing remote directory.
Use this when you want to reference a directory that already exists in remote storage without uploading it.
Example:
```python
@env.task
async def process_existing_directory() -> int:
d = Dir.from_existing_remote("s3://my-bucket/data/")
files = await d.list_files()
return len(files)
```
Example (With cache key):
```python
@env.task
async def process_with_cache_key() -> int:
d = Dir.from_existing_remote("s3://my-bucket/data/", dir_cache_key="abc123")
files = await d.list_files()
return len(files)
```
| Parameter | Type | Description |
|-|-|-|
| `remote_path` | `str` | The remote path to the existing directory |
| `dir_cache_key` | `Optional[str]` | Optional hash value to use for cache key computation. If not specified, the cache key will be computed based on the directory's attributes. |
**Returns:** A new Dir instance pointing to the existing remote directory
### from_local()
```python
def from_local(
local_path: Union[str, Path],
remote_destination: Optional[str],
dir_cache_key: Optional[str],
batch_size: Optional[int],
) -> Dir[T]
```
Asynchronously create a new Dir by uploading a local directory to remote storage.
Use this in async tasks when you have a local directory that needs to be uploaded to remote storage.
Example (Async):
```python
@env.task
async def upload_local_directory() -> Dir:
# Create a local directory with files
os.makedirs("/tmp/data_dir", exist_ok=True)
with open("/tmp/data_dir/file1.txt", "w") as f:
f.write("data1")
# Upload to remote storage
remote_dir = await Dir.from_local("/tmp/data_dir/")
return remote_dir
```
Example (Async - With specific destination):
```python
@env.task
async def upload_to_specific_path() -> Dir:
remote_dir = await Dir.from_local("/tmp/data_dir/", "s3://my-bucket/data/")
return remote_dir
```
Example (Async - With cache key):
```python
@env.task
async def upload_with_cache_key() -> Dir:
remote_dir = await Dir.from_local("/tmp/data_dir/", dir_cache_key="my_cache_key_123")
return remote_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local directory |
| `remote_destination` | `Optional[str]` | Optional remote path to store the directory. If None, a path will be automatically generated. |
| `dir_cache_key` | `Optional[str]` | Optional precomputed hash value to use for cache key computation when this Dir is used as an input to discoverable tasks. If not specified, the cache key will be based on directory attributes. |
| `batch_size` | `Optional[int]` | Optional concurrency limit for uploading files. If not specified, the default value is determined by the FLYTE_IO_BATCH_SIZE environment variable (default: 32). |
**Returns:** A new Dir instance pointing to the uploaded directory
### from_local_sync()
```python
def from_local_sync(
local_path: Union[str, Path],
remote_destination: Optional[str],
dir_cache_key: Optional[str],
) -> Dir[T]
```
Synchronously create a new Dir by uploading a local directory to remote storage.
Use this in non-async tasks when you have a local directory that needs to be uploaded to remote storage.
Example (Sync):
```python
@env.task
def upload_local_directory_sync() -> Dir:
# Create a local directory with files
os.makedirs("/tmp/data_dir", exist_ok=True)
with open("/tmp/data_dir/file1.txt", "w") as f:
f.write("data1")
# Upload to remote storage
remote_dir = Dir.from_local_sync("/tmp/data_dir/")
return remote_dir
```
Example (Sync - With specific destination):
```python
@env.task
def upload_to_specific_path_sync() -> Dir:
remote_dir = Dir.from_local_sync("/tmp/data_dir/", "s3://my-bucket/data/")
return remote_dir
```
Example (Sync - With cache key):
```python
@env.task
def upload_with_cache_key_sync() -> Dir:
remote_dir = Dir.from_local_sync("/tmp/data_dir/", dir_cache_key="my_cache_key_123")
return remote_dir
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local directory |
| `remote_destination` | `Optional[str]` | Optional remote path to store the directory. If None, a path will be automatically generated. |
| `dir_cache_key` | `Optional[str]` | Optional precomputed hash value to use for cache key computation when this Dir is used as an input to discoverable tasks. If not specified, the cache key will be based on directory attributes. |
**Returns:** A new Dir instance pointing to the uploaded directory
### get_file()
```python
def get_file(
file_name: str,
) -> Optional[File[T]]
```
Asynchronously get a specific file from the directory by name.
Use this when you know the name of a specific file in the directory you want to access.
Example (Async):
```python
@env.task
async def read_specific_file(d: Dir) -> str:
file = await d.get_file("data.csv")
if file:
async with file.open("rb") as f:
content = await f.read()
return content.decode("utf-8")
return "File not found"
```
| Parameter | Type | Description |
|-|-|-|
| `file_name` | `str` | The name of the file to get |
**Returns:** A File instance if the file exists, None otherwise
### get_file_sync()
```python
def get_file_sync(
file_name: str,
) -> Optional[File[T]]
```
Synchronously get a specific file from the directory by name.
Use this in non-async tasks when you know the name of a specific file in the directory you want to access.
Example (Sync):
```python
@env.task
def read_specific_file_sync(d: Dir) -> str:
file = d.get_file_sync("data.csv")
if file:
with file.open_sync("rb") as f:
content = f.read()
return content.decode("utf-8")
return "File not found"
```
| Parameter | Type | Description |
|-|-|-|
| `file_name` | `str` | The name of the file to get |
**Returns:** A File instance if the file exists, None otherwise
### iter_arrow_batches()
```python
def iter_arrow_batches(
batch_size: int,
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> AsyncGenerator[Any, None]
```
Async generator that yields Arrow RecordBatches across all shards.
| Parameter | Type | Description |
|-|-|-|
| `batch_size` | `int` | Max records per RecordBatch (default 65536). |
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
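Conceptually, each batch is bounded at `batch_size` records, so memory stays proportional to one batch rather than a whole shard. A minimal plain-Python sketch of that grouping (ordinary lists stand in for Arrow RecordBatches; this is an illustration, not the plugin's implementation):

```python
def bounded_batches(records, batch_size=65536):
    """Group an iterable of records into lists of at most batch_size items."""
    batch = []
    for record in records:
        batch.append(record)
        if len(batch) >= batch_size:
            yield batch  # emit a full batch and start a new one
            batch = []
    if batch:
        yield batch  # flush the final, possibly short, batch

batches = list(bounded_batches(range(10), batch_size=4))
# batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```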
### iter_arrow_batches_sync()
```python
def iter_arrow_batches_sync(
batch_size: int,
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> Generator[Any, None, None]
```
Sync generator that yields Arrow RecordBatches across all shards.
| Parameter | Type | Description |
|-|-|-|
| `batch_size` | `int` | Max records per RecordBatch (default 65536). |
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
### iter_batches()
```python
def iter_batches(
batch_size: int,
on_error: Literal['raise', 'skip'] | ErrorHandler,
prefetch: bool,
queue_size: int,
) -> AsyncGenerator[list[dict[str, Any]], None]
```
Async generator that yields lists of records in batches.
| Parameter | Type | Description |
|-|-|-|
| `batch_size` | `int` | Max records per batch (default 1000). |
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
| `prefetch` | `bool` | Overlap next-shard I/O with current-shard processing. |
| `queue_size` | `int` | Memory safety bound on the read-ahead buffer. |
### iter_batches_sync()
```python
def iter_batches_sync(
batch_size: int,
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> Generator[list[dict[str, Any]], None, None]
```
Sync generator that yields lists of records in batches.
| Parameter | Type | Description |
|-|-|-|
| `batch_size` | `int` | Max records per batch (default 1000). |
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
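The `on_error` contract is the same across these iterators: `"raise"` propagates the parse error, `"skip"` drops the bad line, and a callable receives the failure details. A self-contained sketch of that dispatch over raw JSONL lines (illustrative only, using the stdlib `json` module rather than the plugin's `orjson`-backed reader):

```python
import json

def parse_jsonl_lines(lines, on_error="raise"):
    """Parse JSONL lines, dispatching failures per the on_error contract."""
    records = []
    for lineno, raw in enumerate(lines, start=1):
        try:
            records.append(json.loads(raw))
        except json.JSONDecodeError as exc:
            if on_error == "skip":
                continue  # silently drop the malformed line
            if callable(on_error):
                on_error(lineno, raw, exc)  # hand details to the error handler
                continue
            raise  # on_error == "raise"
    return records

bad_lines = []
records = parse_jsonl_lines(['{"a": 1}', "not json"],
                            on_error=lambda n, line, exc: bad_lines.append(n))
# records == [{"a": 1}], bad_lines == [2]
```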
### iter_records()
```python
def iter_records(
on_error: Literal['raise', 'skip'] | ErrorHandler,
prefetch: bool,
queue_size: int,
) -> AsyncGenerator[dict[str, Any], None]
```
Async generator that yields records from all shards in sorted order.
When *prefetch* is True (default), the next shard is read into a
bounded queue concurrently while the current shard is being yielded.
This overlaps network I/O with processing without buffering more
than one shard in memory.
| Parameter | Type | Description |
|-|-|-|
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable `(line_number, raw_line, exception) -> None`. |
| `prefetch` | `bool` | Overlap next-shard network I/O with current-shard processing for higher throughput. |
| `queue_size` | `int` | Memory safety bound on the read-ahead buffer (default 8192). |
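The prefetch behaviour described above can be pictured as a producer filling a bounded `asyncio.Queue` while the consumer drains it; the queue bound is what limits how much read-ahead sits in memory. A simplified, self-contained sketch (not the plugin's actual reader; records are assumed non-None so `None` can serve as a sentinel):

```python
import asyncio

async def _produce(shards, queue):
    # Reads shards ahead of the consumer; put() blocks when the queue is
    # full, which bounds how far ahead we read.
    for shard in shards:
        for record in shard:
            await queue.put(record)
    await queue.put(None)  # sentinel: end of data

async def iter_with_prefetch(shards, queue_size=8):
    queue = asyncio.Queue(maxsize=queue_size)
    producer = asyncio.create_task(_produce(shards, queue))
    while (record := await queue.get()) is not None:
        yield record
    await producer  # surface any producer exception

async def main():
    shards = [[1, 2], [3, 4], [5]]
    return [r async for r in iter_with_prefetch(shards)]

result = asyncio.run(main())
# result == [1, 2, 3, 4, 5]
```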
### iter_records_sync()
```python
def iter_records_sync(
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> Generator[dict[str, Any], None, None]
```
Sync generator that yields records from all shards in sorted order.
| Parameter | Type | Description |
|-|-|-|
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
### list_files()
```python
def list_files()
```
Asynchronously get a list of all files in the directory (non-recursive).
Use this when you need a list of all files in the top-level directory at once.
Example (Async):
```python
@env.task
async def count_files(d: Dir) -> int:
files = await d.list_files()
return len(files)
```
Example (Async - Process files):
```python
@env.task
async def process_all_files(d: Dir) -> list[str]:
files = await d.list_files()
contents = []
for file in files:
async with file.open("rb") as f:
content = await f.read()
contents.append(content.decode("utf-8"))
return contents
```
**Returns:** A list of File objects for files in the top-level directory
### list_files_sync()
```python
def list_files_sync()
```
Synchronously get a list of all files in the directory (non-recursive).
Use this in non-async tasks when you need a list of all files in the top-level directory at once.
Example (Sync):
```python
@env.task
def count_files_sync(d: Dir) -> int:
files = d.list_files_sync()
return len(files)
```
Example (Sync - Process files):
```python
@env.task
def process_all_files_sync(d: Dir) -> list[str]:
files = d.list_files_sync()
contents = []
for file in files:
with file.open_sync("rb") as f:
content = f.read()
contents.append(content.decode("utf-8"))
return contents
```
**Returns:** A list of File objects for files in the top-level directory
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | The context. |
### new_remote()
```python
def new_remote(
dir_name: Optional[str],
hash: Optional[str],
) -> Dir[T]
```
Create a new Dir reference for a remote directory that will be written to.
Use this when you want to create a new directory and write files into it
directly without creating a local directory first.
Example:
```python
@env.task
async def create() -> Dir:
    d = Dir.new_remote("output")
    # write files into d ...
    return d
```
| Parameter | Type | Description |
|-|-|-|
| `dir_name` | `Optional[str]` | Optional name for the remote directory. If not set, a generated name will be used. |
| `hash` | `Optional[str]` | Optional precomputed hash value to use for cache key computation when this Dir is used as an input to discoverable tasks. |
**Returns:** A new Dir instance with a generated remote path.
### pre_init()
```python
def pre_init(
data,
)
```
Internal: Pydantic validator to set default name from path. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `data` | | |
### schema_match()
```python
def schema_match(
incoming: dict,
)
```
Internal: Check if incoming schema matches Dir schema. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `incoming` | `dict` | |
### walk()
```python
def walk(
recursive: bool,
max_depth: Optional[int],
) -> AsyncIterator[File[T]]
```
Asynchronously walk through the directory and yield File objects.
Use this to iterate through all files in a directory. Each yielded File can be read directly without
downloading.
Example (Async - Recursive):
```python
@env.task
async def list_all_files(d: Dir) -> list[str]:
file_names = []
async for file in d.walk(recursive=True):
file_names.append(file.name)
return file_names
```
Example (Async - Non-recursive):
```python
@env.task
async def list_top_level_files(d: Dir) -> list[str]:
file_names = []
async for file in d.walk(recursive=False):
file_names.append(file.name)
return file_names
```
Example (Async - With max depth):
```python
@env.task
async def list_files_max_depth(d: Dir) -> list[str]:
file_names = []
async for file in d.walk(recursive=True, max_depth=2):
file_names.append(file.name)
return file_names
```
**Yields:** File objects for each file found in the directory
| Parameter | Type | Description |
|-|-|-|
| `recursive` | `bool` | If True, recursively walk subdirectories. If False, only list files in the top-level directory. |
| `max_depth` | `Optional[int]` | Maximum depth for recursive walking. If None, walk through all subdirectories. |
### walk_sync()
```python
def walk_sync(
recursive: bool,
file_pattern: str,
max_depth: Optional[int],
) -> Iterator[File[T]]
```
Synchronously walk through the directory and yield File objects.
Use this in non-async tasks to iterate through all files in a directory.
Example (Sync - Recursive):
```python
@env.task
def list_all_files_sync(d: Dir) -> list[str]:
file_names = []
for file in d.walk_sync(recursive=True):
file_names.append(file.name)
return file_names
```
Example (Sync - With file pattern):
```python
@env.task
def list_text_files(d: Dir) -> list[str]:
file_names = []
for file in d.walk_sync(recursive=True, file_pattern="*.txt"):
file_names.append(file.name)
return file_names
```
Example (Sync - Recursive with max depth):
```python
@env.task
def list_files_limited(d: Dir) -> list[str]:
file_names = []
for file in d.walk_sync(recursive=True, max_depth=2):
file_names.append(file.name)
return file_names
```
**Yields:** File objects for each file found in the directory
| Parameter | Type | Description |
|-|-|-|
| `recursive` | `bool` | If True, recursively walk subdirectories. If False, only list files in the top-level directory. |
| `file_pattern` | `str` | Glob pattern to filter files (e.g., "*.txt", "*.csv"). Default is "*" (all files). |
| `max_depth` | `Optional[int]` | Maximum depth for recursive walking. If None, walk through all subdirectories. |
### writer()
```python
def writer(
shard_extension: str,
max_records_per_shard: int | None,
max_bytes_per_shard: int,
flush_bytes: int,
compression_level: int,
) -> AsyncGenerator[JsonlDirWriter, None]
```
Async context manager returning a `JsonlDirWriter`.
Scans the directory for existing shards and starts writing from the
next available index, so appending to an existing directory is safe.
| Parameter | Type | Description |
|-|-|-|
| `shard_extension` | `str` | File extension (e.g. `.jsonl` or `.jsonl.zst`). |
| `max_records_per_shard` | `int \| None` | Roll after this many records (None = no limit). |
| `max_bytes_per_shard` | `int` | Roll after this many uncompressed bytes (default 256 MB). |
| `flush_bytes` | `int` | Buffer flush threshold in bytes (default 1 MB). |
| `compression_level` | `int` | Zstd level (default 3, only for `.jsonl.zst`). |
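The "next available index" scan means appending to an existing directory never overwrites shards already present. A sketch of that index selection under an assumed `part-<n>.<ext>` naming scheme (the actual shard naming is internal to the plugin and may differ):

```python
import re

def next_shard_index(existing_names, extension=".jsonl"):
    """Return the first shard index after the highest one already present."""
    pattern = re.compile(r"part-(\d+)" + re.escape(extension))
    indices = [int(m.group(1)) for name in existing_names
               if (m := pattern.fullmatch(name))]
    return max(indices) + 1 if indices else 0

idx = next_shard_index(["part-00000.jsonl", "part-00001.jsonl"])
# idx == 2; an empty directory starts at index 0
```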
### writer_sync()
```python
def writer_sync(
shard_extension: str,
max_records_per_shard: int | None,
max_bytes_per_shard: int,
flush_bytes: int,
compression_level: int,
) -> Generator[JsonlDirWriterSync, None, None]
```
Sync context manager returning a `JsonlDirWriterSync`.
See `writer` for argument descriptions.
| Parameter | Type | Description |
|-|-|-|
| `shard_extension` | `str` | File extension (e.g. `.jsonl` or `.jsonl.zst`). |
| `max_records_per_shard` | `int \| None` | Roll after this many records (None = no limit). |
| `max_bytes_per_shard` | `int` | Roll after this many uncompressed bytes (default 256 MB). |
| `flush_bytes` | `int` | Buffer flush threshold in bytes (default 1 MB). |
| `compression_level` | `int` | Zstd level (default 3, only for `.jsonl.zst`). |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/jsonl/packages/flyteplugins.jsonl/jsonlfile ===
# JsonlFile
**Package:** `flyteplugins.jsonl`
A file type for JSONL (JSON Lines) files, backed by `orjson` for fast
serialisation.
Provides streaming read and write methods that process one record at a time
without loading the entire file into memory. Inherits all `File`
capabilities (remote storage, upload/download, etc.).
Supports zstd-compressed files transparently via extension detection
(`.jsonl.zst` / `.jsonl.zstd`).
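The extension rule above amounts to plain suffix matching; a trivial sketch of the detection (illustrative, not the plugin's internal code):

```python
def is_zstd_jsonl(path: str) -> bool:
    """True when the path uses one of the documented zstd extensions."""
    return path.endswith((".jsonl.zst", ".jsonl.zstd"))

compressed = is_zstd_jsonl("data.jsonl.zst")  # True
plain = is_zstd_jsonl("data.jsonl")           # False
```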
Example (Async read - compressed or uncompressed):
```python
@env.task
async def process(f: JsonlFile):
async for record in f.iter_records():
print(record)
```
Example (Async write - compressed or uncompressed):
```python
@env.task
async def create() -> JsonlFile:
f = JsonlFile.new_remote("data.jsonl")
async with f.writer() as w:
await w.write({"key": "value"})
return f
```
Example (Sync write - compressed or uncompressed):
```python
@env.task
def create_sync() -> JsonlFile:
f = JsonlFile.new_remote("data.jsonl")
with f.writer_sync() as w:
w.write({"key": "value"})
return f
```
## Parameters
```python
class JsonlFile(
path: str,
name: typing.Optional[str],
format: str,
hash: typing.Optional[str],
hash_method: typing.Optional[flyte.io._hashing_io.HashMethod],
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.ValidationError) if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `name` | `typing.Optional[str]` | |
| `format` | `str` | |
| `hash` | `typing.Optional[str]` | |
| `hash_method` | `typing.Optional[flyte.io._hashing_io.HashMethod]` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `lazy_uploader` | `None` | |
## Methods
| Method | Description |
|-|-|
| **download()** | Asynchronously download the file to a local path. |
| **download_sync()** | Synchronously download the file to a local path. |
| **exists()** | Asynchronously check if the file exists. |
| **exists_sync()** | Synchronously check if the file exists. |
| **from_existing_remote()** | Create a File reference from an existing remote file. |
| **from_local()** | Asynchronously create a new File object from a local file by uploading it to remote storage. |
| **from_local_sync()** | Synchronously create a new File object from a local file by uploading it to remote storage. |
| **iter_arrow_batches()** | Stream JSONL as Arrow RecordBatches. |
| **iter_arrow_batches_sync()** | Sync generator that yields Arrow RecordBatches. |
| **iter_records()** | Async generator that yields parsed dicts line by line. |
| **iter_records_sync()** | Sync generator that yields parsed dicts line by line. |
| **model_post_init()** | This function is meant to behave like a BaseModel method to initialise private attributes. |
| **named_remote()** | Create a File reference whose remote path is derived deterministically from *name*. |
| **new_remote()** | Create a new File reference for a remote file that will be written to. |
| **open()** | Asynchronously open the file and return a file-like object. |
| **open_sync()** | Synchronously open the file and return a file-like object. |
| **pre_init()** | Internal: Pydantic validator to set default name from path. |
| **schema_match()** | Internal: Check if incoming schema matches File schema. |
| **writer()** | Async context manager returning a `JsonlWriter` for streaming writes. |
| **writer_sync()** | Sync context manager returning a `JsonlWriterSync` for streaming writes. |
### download()
```python
def download(
local_path: Optional[Union[str, Path]],
) -> str
```
Asynchronously download the file to a local path.
Use this when you need to download a remote file to your local filesystem for processing.
Example (Async):
```python
@env.task
async def download_and_process(f: File) -> str:
local_path = await f.download()
# Now process the local file
with open(local_path, "r") as fh:
return fh.read()
```
Example (Download to specific path):
```python
@env.task
async def download_to_path(f: File) -> str:
local_path = await f.download("/tmp/myfile.csv")
return local_path
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the file to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded file
### download_sync()
```python
def download_sync(
local_path: Optional[Union[str, Path]],
) -> str
```
Synchronously download the file to a local path.
Use this in non-async tasks when you need to download a remote file to your local filesystem.
Example (Sync):
```python
@env.task
def download_and_process_sync(f: File) -> str:
local_path = f.download_sync()
# Now process the local file
with open(local_path, "r") as fh:
return fh.read()
```
Example (Download to specific path):
```python
@env.task
def download_to_path_sync(f: File) -> str:
local_path = f.download_sync("/tmp/myfile.csv")
return local_path
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Optional[Union[str, Path]]` | The local path to download the file to. If None, a temporary directory will be used and a path will be generated. |
**Returns:** The absolute path to the downloaded file
### exists()
```python
def exists()
```
Asynchronously check if the file exists.
Example (Async):
```python
@env.task
async def check_file(f: File) -> bool:
if await f.exists():
print("File exists!")
return True
return False
```
**Returns:** True if the file exists, False otherwise
### exists_sync()
```python
def exists_sync()
```
Synchronously check if the file exists.
Use this in non-async tasks or when you need synchronous file existence checking.
Example (Sync):
```python
@env.task
def check_file_sync(f: File) -> bool:
if f.exists_sync():
print("File exists!")
return True
return False
```
**Returns:** True if the file exists, False otherwise
### from_existing_remote()
```python
def from_existing_remote(
remote_path: str,
file_cache_key: Optional[str],
) -> File[T]
```
Create a File reference from an existing remote file.
Use this when you want to reference a file that already exists in remote storage without uploading it.
Example:
```python
@env.task
async def process_existing_file() -> str:
file = File.from_existing_remote("s3://my-bucket/data.csv")
async with file.open("rb") as f:
content = await f.read()
return content.decode("utf-8")
```
| Parameter | Type | Description |
|-|-|-|
| `remote_path` | `str` | The remote path to the existing file |
| `file_cache_key` | `Optional[str]` | Optional hash value to use for cache key computation. If not specified, the cache key will be computed based on the file's attributes (path, name, format). |
**Returns:** A new File instance pointing to the existing remote file
### from_local()
```python
def from_local(
local_path: Union[str, Path],
remote_destination: Optional[str],
hash_method: Optional[HashMethod | str],
) -> File[T]
```
Asynchronously create a new File object from a local file by uploading it to remote storage.
Use this in async tasks when you have a local file that needs to be uploaded to remote storage.
Example (Async):
```python
@env.task
async def upload_local_file() -> File:
    # Create a local file
    async with aiofiles.open("/tmp/data.csv", "w") as f:
        await f.write("col1,col2\n")
    # Upload to remote storage
    remote_file = await File.from_local("/tmp/data.csv")
    return remote_file
```
Example (With specific destination):
```python
@env.task
async def upload_to_specific_path() -> File:
remote_file = await File.from_local("/tmp/data.csv", "s3://my-bucket/data.csv")
return remote_file
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local file |
| `remote_destination` | `Optional[str]` | Optional remote path to store the file. If None, a path will be automatically generated. |
| `hash_method` | `Optional[HashMethod \| str]` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will compute the hash during upload. If not specified, the cache key will be based on file attributes. |
**Returns:** A new File instance pointing to the uploaded remote file
### from_local_sync()
```python
def from_local_sync(
local_path: Union[str, Path],
remote_destination: Optional[str],
hash_method: Optional[HashMethod | str],
) -> File[T]
```
Synchronously create a new File object from a local file by uploading it to remote storage.
Use this in non-async tasks when you have a local file that needs to be uploaded to remote storage.
Example (Sync):
```python
@env.task
def upload_local_file_sync() -> File:
    # Create a local file
    with open("/tmp/data.csv", "w") as f:
        f.write("col1,col2\n")
    # Upload to remote storage
    remote_file = File.from_local_sync("/tmp/data.csv")
    return remote_file
```
Example (With specific destination):
```python
@env.task
def upload_to_specific_path_sync() -> File:
remote_file = File.from_local_sync("/tmp/data.csv", "s3://my-bucket/data.csv")
return remote_file
```
| Parameter | Type | Description |
|-|-|-|
| `local_path` | `Union[str, Path]` | Path to the local file |
| `remote_destination` | `Optional[str]` | Optional remote path to store the file. If None, a path will be automatically generated. |
| `hash_method` | `Optional[HashMethod \| str]` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will compute the hash during upload. If not specified, the cache key will be based on file attributes. |
**Returns:** A new File instance pointing to the uploaded remote file
### iter_arrow_batches()
```python
def iter_arrow_batches(
batch_size: int,
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> AsyncGenerator[Any, None]
```
Stream JSONL as Arrow RecordBatches.
Memory usage is bounded by batch_size.
| Parameter | Type | Description |
|-|-|-|
| `batch_size` | `int` | Max records per RecordBatch. |
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
### iter_arrow_batches_sync()
```python
def iter_arrow_batches_sync(
batch_size: int,
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> Generator[Any, None, None]
```
Sync generator that yields Arrow RecordBatches.
Memory usage is bounded by batch_size.
| Parameter | Type | Description |
|-|-|-|
| `batch_size` | `int` | Max records per RecordBatch. |
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
### iter_records()
```python
def iter_records(
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> AsyncGenerator[dict[str, Any], None]
```
Async generator that yields parsed dicts line by line.
| Parameter | Type | Description |
|-|-|-|
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
### iter_records_sync()
```python
def iter_records_sync(
on_error: Literal['raise', 'skip'] | ErrorHandler,
) -> Generator[dict[str, Any], None, None]
```
Sync generator that yields parsed dicts line by line.
| Parameter | Type | Description |
|-|-|-|
| `on_error` | `Literal['raise', 'skip'] \| ErrorHandler` | `"raise"` (default), `"skip"`, or a callable. |
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | The context. |
### named_remote()
```python
def named_remote(
name: str,
) -> File[T]
```
Create a File reference whose remote path is derived deterministically from *name*.
Unlike `new_remote`, which generates a random path on every call, this method
produces the same path for the same *name* within a given task execution. This makes
it safe across retries: the first attempt uploads to the path and subsequent retries
resolve to the identical location without re-uploading.
The path is optionally namespaced by the node ID extracted from the backend
raw-data path, which follows the convention:
`{run_name}-{node_id}-{attempt_index}`
If extraction fails, the function falls back to the run base directory alone.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | Plain filename (e.g., "data.csv"). Must not contain path separators. |
**Returns:** A `File` instance whose path is stable across retries.
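As an illustration of the naming convention above, the node ID can be recovered from the final path segment by splitting on dashes and taking the second-to-last component. This is a hypothetical parse for clarity; the real extraction is internal to the SDK and may differ:

```python
def extract_node_id(raw_data_path: str):
    """Parse '{run_name}-{node_id}-{attempt_index}' from the final path segment."""
    segment = raw_data_path.rstrip("/").rsplit("/", 1)[-1]
    parts = segment.split("-")
    if len(parts) >= 3:
        return parts[-2]  # attempt index is last, node ID just before it
    return None  # extraction failed: fall back to the run base directory

node = extract_node_id("s3://bucket/runs/my-run-n0-0")
# node == "n0"
```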
### new_remote()
```python
def new_remote(
file_name: Optional[str],
hash_method: Optional[HashMethod | str],
) -> File[T]
```
Create a new File reference for a remote file that will be written to.
Use this when you want to create a new file and write to it directly without creating a local file first.
Example (Async):
```python
@env.task
async def create_csv() -> File:
df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
file = File.new_remote()
async with file.open("wb") as f:
df.to_csv(f)
return file
```
| Parameter | Type | Description |
|-|-|-|
| `file_name` | `Optional[str]` | Optional string specifying a remote file name. If not set, a generated file name will be returned. |
| `hash_method` | `Optional[HashMethod \| str]` | Optional HashMethod or string to use for cache key computation. If a string is provided, it will be used as a precomputed cache key. If a HashMethod is provided, it will be used to compute the hash as data is written. |
**Returns:** A new File instance with a generated remote path
### open()
```python
def open(
mode: str,
block_size: Optional[int],
cache_type: str,
cache_options: Optional[dict],
compression: Optional[str],
kwargs,
) -> AsyncGenerator[Union[AsyncWritableFile, AsyncReadableFile, 'HashingWriter'], None]
```
Asynchronously open the file and return a file-like object.
Use this method in async tasks to read from or write to files directly.
Example (Async Read):
```python
@env.task
async def read_file(f: File) -> str:
async with f.open("rb") as fh:
content = bytes(await fh.read())
return content.decode("utf-8")
```
Example (Async Write):
```python
@env.task
async def write_file() -> File:
f = File.new_remote()
async with f.open("wb") as fh:
await fh.write(b"Hello, World!")
return f
```
Example (Streaming Read):
```python
@env.task
async def stream_read(f: File) -> str:
content_parts = []
async with f.open("rb", block_size=1024) as fh:
while True:
chunk = await fh.read()
if not chunk:
break
content_parts.append(chunk)
return b"".join(content_parts).decode("utf-8")
```
| Parameter | Type | Description |
|-|-|-|
| `mode` | `str` | The mode to open the file in (default: 'rb'). Common modes: 'rb' (read binary), 'wb' (write binary), 'rt' (read text), 'wt' (write text) |
| `block_size` | `Optional[int]` | Size of blocks for reading in bytes. Useful for streaming large files. |
| `cache_type` | `str` | Caching mechanism to use ('readahead', 'mmap', 'bytes', 'none') |
| `cache_options` | `Optional[dict]` | Dictionary of options for the cache |
| `compression` | `Optional[str]` | Compression format or None for auto-detection |
| `kwargs` | `**kwargs` | |
**Returns:** An async file-like object that can be used with async read/write operations
### open_sync()
```python
def open_sync(
mode: str,
block_size: Optional[int],
cache_type: str,
cache_options: Optional[dict],
compression: Optional[str],
kwargs,
) -> Generator[IO[Any], None, None]
```
Synchronously open the file and return a file-like object.
Use this method in non-async tasks to read from or write to files directly.
Example (Sync Read):
```python
@env.task
def read_file_sync(f: File) -> str:
with f.open_sync("rb") as fh:
content = fh.read()
return content.decode("utf-8")
```
Example (Sync Write):
```python
@env.task
def write_file_sync() -> File:
f = File.new_remote()
with f.open_sync("wb") as fh:
fh.write(b"Hello, World!")
return f
```
| Parameter | Type | Description |
|-|-|-|
| `mode` | `str` | The mode to open the file in (default: 'rb'). Common modes: 'rb' (read binary), 'wb' (write binary), 'rt' (read text), 'wt' (write text) |
| `block_size` | `Optional[int]` | Size of blocks for reading in bytes. Useful for streaming large files. |
| `cache_type` | `str` | Caching mechanism to use ('readahead', 'mmap', 'bytes', 'none') |
| `cache_options` | `Optional[dict]` | Dictionary of options for the cache |
| `compression` | `Optional[str]` | Compression format or None for auto-detection |
| `kwargs` | `**kwargs` | |
**Returns:** A file-like object that can be used with standard read/write operations
### pre_init()
```python
def pre_init(
data,
)
```
Internal: Pydantic validator to set default name from path. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `data` | | |
### schema_match()
```python
def schema_match(
incoming: dict,
)
```
Internal: Check if incoming schema matches File schema. Not intended for direct use.
| Parameter | Type | Description |
|-|-|-|
| `incoming` | `dict` | |
### writer()
```python
def writer(
flush_bytes: int,
compression_level: int,
) -> AsyncGenerator[JsonlWriter, None]
```
Async context manager returning a `JsonlWriter` for streaming writes.
If the file path ends in `.jsonl.zst`, output is zstd-compressed.
| Parameter | Type | Description |
|-|-|-|
| `flush_bytes` | `int` | Buffer flush threshold in bytes (default 1 MB). |
| `compression_level` | `int` | Zstd compression level (default 3). Only used for `.jsonl.zst` paths. Higher = smaller files, slower writes. |
### writer_sync()
```python
def writer_sync(
flush_bytes: int,
compression_level: int,
) -> Generator[JsonlWriterSync, None, None]
```
Sync context manager returning a `JsonlWriterSync` for streaming writes.
If the file path ends in `.jsonl.zst`, output is zstd-compressed.
| Parameter | Type | Description |
|-|-|-|
| `flush_bytes` | `int` | Buffer flush threshold in bytes (default 1 MB). |
| `compression_level` | `int` | Zstd compression level (default 3). Only used for `.jsonl.zst` paths. Higher = smaller files, slower writes. |
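The buffered streaming pattern behind `writer_sync()` can be sketched with the standard library alone. This is an illustrative stand-in, not the actual `JsonlWriterSync` implementation; `write_jsonl` and its flush logic are assumptions mirroring the documented `flush_bytes` behavior:

```python
import io
import json

def write_jsonl(records, buf, flush_bytes=1024 * 1024):
    """Append records as JSON Lines, flushing once the pending bytes
    exceed the threshold (mirrors the documented flush_bytes default of 1 MB)."""
    pending = 0
    for record in records:
        data = (json.dumps(record) + "\n").encode("utf-8")
        buf.write(data)
        pending += len(data)
        if pending >= flush_bytes:
            buf.flush()
            pending = 0
    buf.flush()

buf = io.BytesIO()
write_jsonl([{"id": 1}, {"id": 2}], buf)
lines = buf.getvalue().decode("utf-8").splitlines()
```

A `.jsonl.zst` path would additionally route each flushed chunk through a zstd compressor at the configured `compression_level`.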
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/mlflow ===
# MLflow
## Subpages
- **Integrations > MLflow > Classes**
- **Integrations > MLflow > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/mlflow/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > MLflow > Packages > flyteplugins.mlflow > Mlflow** | MLflow UI link for Flyte tasks. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/mlflow/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > MLflow > Packages > flyteplugins.mlflow** | MLflow run management, autologging, and UI links for Flyte tasks. |
## Subpages
- **Integrations > MLflow > Packages > flyteplugins.mlflow**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/mlflow/packages/flyteplugins.mlflow ===
# flyteplugins.mlflow
## Key features:
- Automatic MLflow run management with `@mlflow_run` decorator
- Built-in autologging support via `autolog=True` parameter
- Auto-generated MLflow UI links via `link_host` config and the `Mlflow` link class
- Parent/child task support with run sharing
- Distributed training support (only rank 0 logs to MLflow)
- Configuration management with `mlflow_config()`
## Basic usage:
1. Manual logging with `@mlflow_run`:
```python
from flyteplugins.mlflow import mlflow_run, get_mlflow_run

@mlflow_run(
    tracking_uri="http://localhost:5000",
    experiment_name="my-experiment",
    tags={"team": "ml"},
)
@env.task
async def train_model(learning_rate: float) -> str:
    import mlflow

    mlflow.log_param("lr", learning_rate)
    mlflow.log_metric("loss", 0.5)
    run = get_mlflow_run()
    return run.info.run_id
```
2. Automatic logging with `@mlflow_run(autolog=True)`:
```python
from flyteplugins.mlflow import mlflow_run

@mlflow_run(
    autolog=True,
    framework="sklearn",
    tracking_uri="http://localhost:5000",
    log_models=True,
    log_datasets=False,
    experiment_id="846992856162999",
)
@env.task
async def train_sklearn_model():
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification()
    model = LogisticRegression()
    model.fit(X, y)  # Autolog captures parameters, metrics, and the model
```
3. Workflow-level configuration with `mlflow_config()`:
```python
from flyteplugins.mlflow import mlflow_config

r = flyte.with_runcontext(
    custom_context=mlflow_config(
        tracking_uri="http://localhost:5000",
        experiment_id="846992856162999",
        tags={"team": "ml"},
    )
).run(train_model, learning_rate=0.001)
```
4. Per-task config overrides with context manager:
```python
@mlflow_run
@env.task
async def parent_task():
    # Override config for a specific child task
    with mlflow_config(run_mode="new", tags={"role": "child"}):
        await child_task()
```
5. Run modes: control run creation vs sharing:
```python
@mlflow_run                  # "auto": new run if no parent, else share the parent's
@mlflow_run(run_mode="new")  # Always create a new run
```
6. HPO: the objective can be a Flyte task with `run_mode="new"`:
```python
@mlflow_run(run_mode="new")
@env.task
def objective(params: dict) -> float:
    import mlflow

    mlflow.log_params(params)
    loss = train(params)
    mlflow.log_metric("loss", loss)
    return loss
```
7. Distributed training (only rank 0 logs):
```python
@mlflow_run  # Auto-detects rank from the RANK env var
@env.task
async def distributed_train():
    ...
```
8. MLflow UI links: auto-generated via `link_host`:
```python
from flyteplugins.mlflow import Mlflow, mlflow_config

# Set link_host at the workflow level; children with an Mlflow() link
# get the URL automatically after the parent creates the run.
r = flyte.with_runcontext(
    custom_context=mlflow_config(
        tracking_uri="http://localhost:5000",
        link_host="http://localhost:5000",
    )
).run(parent_task)

# Attach the link to child tasks:
@mlflow_run
@env.task(links=[Mlflow()])
async def child_task(): ...

# Custom URL template (e.g. Databricks):
mlflow_config(
    link_host="https://dbc-xxx.cloud.databricks.com",
    link_template="{host}/ml/experiments/{experiment_id}/runs/{run_id}",
)
```
Decorator order: `@mlflow_run` must be outermost (applied above `@env.task`):
```python
@mlflow_run
@env.task
async def my_task(): ...

@mlflow_run(autolog=True, framework="sklearn")
@env.task
async def my_task(): ...
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > MLflow > Packages > flyteplugins.mlflow > Mlflow** | MLflow UI link for Flyte tasks. |
### Methods
| Method | Description |
|-|-|
| **Integrations > MLflow > Packages > flyteplugins.mlflow > Methods > get_mlflow_context()** | Retrieve current MLflow configuration from Flyte context. |
| **Integrations > MLflow > Packages > flyteplugins.mlflow > Methods > get_mlflow_run()** | Get the current MLflow run if within a `@mlflow_run` decorated task or trace. |
| **Integrations > MLflow > Packages > flyteplugins.mlflow > Methods > mlflow_config()** | Create MLflow configuration. |
| **Integrations > MLflow > Packages > flyteplugins.mlflow > Methods > mlflow_run()** | Decorator to manage MLflow runs for Flyte tasks and plain functions. |
## Methods
#### get_mlflow_context()
```python
def get_mlflow_context()
```
Retrieve current MLflow configuration from Flyte context.
#### get_mlflow_run()
```python
def get_mlflow_run()
```
Get the current MLflow run if within a `@mlflow_run` decorated task or trace.
The run is started when the `@mlflow_run` decorator enters.
Returns None if not within an `mlflow_run` context.
**Returns:** `mlflow.ActiveRun` | `None`: The current MLflow active run or None.
#### mlflow_config()
```python
def mlflow_config(
tracking_uri: typing.Optional[str],
experiment_name: typing.Optional[str],
experiment_id: typing.Optional[str],
run_name: typing.Optional[str],
run_id: typing.Optional[str],
tags: typing.Optional[dict[str, str]],
run_mode: typing.Literal['auto', 'new', 'nested'],
autolog: bool,
framework: typing.Optional[str],
log_models: typing.Optional[bool],
log_datasets: typing.Optional[bool],
autolog_kwargs: typing.Optional[dict[str, typing.Any]],
link_host: typing.Optional[str],
link_template: typing.Optional[str],
**kwargs,
) -> flyteplugins.mlflow._context._MLflowConfig
```
Create MLflow configuration.
Works in two contexts:
1. With `flyte.with_runcontext()` for global configuration
2. As a context manager to override configuration
| Parameter | Type | Description |
|-|-|-|
| `tracking_uri` | `typing.Optional[str]` | MLflow tracking server URI. |
| `experiment_name` | `typing.Optional[str]` | MLflow experiment name. |
| `experiment_id` | `typing.Optional[str]` | MLflow experiment ID. |
| `run_name` | `typing.Optional[str]` | Human-readable run name. |
| `run_id` | `typing.Optional[str]` | Explicit MLflow run ID. |
| `tags` | `typing.Optional[dict[str, str]]` | MLflow run tags. |
| `run_mode` | `typing.Literal['auto', 'new', 'nested']` | Flyte-specific run mode ("auto", "new", "nested"). |
| `autolog` | `bool` | Enable MLflow autologging. |
| `framework` | `typing.Optional[str]` | Framework-specific autolog (e.g. "sklearn", "pytorch"). |
| `log_models` | `typing.Optional[bool]` | Whether to log models automatically. |
| `log_datasets` | `typing.Optional[bool]` | Whether to log datasets automatically. |
| `autolog_kwargs` | `typing.Optional[dict[str, typing.Any]]` | Extra parameters passed to mlflow.autolog(). |
| `link_host` | `typing.Optional[str]` | MLflow UI host for auto-generating task links. |
| `link_template` | `typing.Optional[str]` | Custom URL template. Defaults to standard MLflow UI format. Available placeholders: `{host}`, `{experiment_id}`, `{run_id}`. |
| `kwargs` | `**kwargs` | |
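The placeholder substitution for `link_template` is plain string formatting. A minimal sketch follows; `render_link` is a hypothetical helper (not part of the plugin API), and `DEFAULT_TEMPLATE` is an assumption about the "standard MLflow UI format" mentioned above:

```python
# NOTE: this default route shape is an assumption; the plugin's real default may differ.
DEFAULT_TEMPLATE = "{host}/#/experiments/{experiment_id}/runs/{run_id}"

def render_link(host: str, experiment_id: str, run_id: str,
                template: str = DEFAULT_TEMPLATE) -> str:
    # Fill the documented placeholders {host}, {experiment_id}, {run_id}.
    return template.format(host=host.rstrip("/"),
                           experiment_id=experiment_id, run_id=run_id)

# The Databricks-style template from the docs:
databricks = render_link(
    "https://dbc-xxx.cloud.databricks.com", "846992856162999", "run123",
    template="{host}/ml/experiments/{experiment_id}/runs/{run_id}",
)
```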
#### mlflow_run()
```python
def mlflow_run(
_func: typing.Optional[~F],
run_mode: typing.Literal['auto', 'new', 'nested'],
tracking_uri: typing.Optional[str],
experiment_name: typing.Optional[str],
experiment_id: typing.Optional[str],
run_name: typing.Optional[str],
run_id: typing.Optional[str],
tags: typing.Optional[dict[str, str]],
autolog: bool,
framework: typing.Optional[str],
log_models: typing.Optional[bool],
log_datasets: typing.Optional[bool],
autolog_kwargs: typing.Optional[dict[str, typing.Any]],
rank: typing.Optional[int],
kwargs,
) -> ~F
```
Decorator to manage MLflow runs for Flyte tasks and plain functions.
Handles both manual logging and autologging. For autologging, pass
`autolog=True` and optionally `framework` to select a specific
framework (e.g. `"sklearn"`).
Decorator order: `@mlflow_run` must be the outermost decorator:
```python
@mlflow_run
@env.task
async def my_task():
    ...
```
| Parameter | Type | Description |
|-|-|-|
| `_func` | `typing.Optional[~F]` | |
| `run_mode` | `typing.Literal['auto', 'new', 'nested']` | "auto" (default), "new", or "nested". - "auto": reuse parent run if available, else create new. - "new": always create a new independent run. - "nested": create a new run nested under the parent via `mlflow.parentRunId` tag. Works across processes/containers. |
| `tracking_uri` | `typing.Optional[str]` | MLflow tracking server URL. |
| `experiment_name` | `typing.Optional[str]` | MLflow experiment name (exclusive with experiment_id). |
| `experiment_id` | `typing.Optional[str]` | MLflow experiment ID (exclusive with experiment_name). |
| `run_name` | `typing.Optional[str]` | Human-readable run name (exclusive with run_id). |
| `run_id` | `typing.Optional[str]` | MLflow run ID (exclusive with run_name). |
| `tags` | `typing.Optional[dict[str, str]]` | Dictionary of tags for the run. |
| `autolog` | `bool` | Enable MLflow autologging. |
| `framework` | `typing.Optional[str]` | MLflow framework name for autolog (e.g. "sklearn", "pytorch"). |
| `log_models` | `typing.Optional[bool]` | Whether to log models automatically (requires autolog). |
| `log_datasets` | `typing.Optional[bool]` | Whether to log datasets automatically (requires autolog). |
| `autolog_kwargs` | `typing.Optional[dict[str, typing.Any]]` | Extra parameters passed to `mlflow.autolog()`. |
| `rank` | `typing.Optional[int]` | Process rank for distributed training (only rank 0 logs). |
| `kwargs` | `**kwargs` | |
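The documented `run_mode` semantics can be summarized as a small decision function. This is a hypothetical sketch of the behavior described above, not the plugin's internal logic:

```python
def resolve_run(run_mode: str, parent_run_id):
    """Decide whether to share, create, or nest an MLflow run.

    Returns (action, parent): action is "create" or "share"; parent is the
    run ID a nested run would be tagged with via mlflow.parentRunId.
    """
    if run_mode == "new" or parent_run_id is None:
        return ("create", None)           # independent run
    if run_mode == "nested":
        return ("create", parent_run_id)  # new run, tagged with the parent
    return ("share", parent_run_id)       # "auto": reuse the parent's run
```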
## Subpages
- **Integrations > MLflow > Packages > flyteplugins.mlflow > Mlflow**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/mlflow/packages/flyteplugins.mlflow/mlflow ===
# Mlflow
**Package:** `flyteplugins.mlflow`
MLflow UI link for Flyte tasks.
Resolves the link URL from one of two sources (in priority order):
1. **Explicit link**: set at definition or override time:
```python
@env.task(links=[Mlflow(link="https://mlflow.example.com/...")])
task.override(links=[Mlflow(link="https://...")])()
```
2. **Context link**: auto-generated from `link_host` (and optional `link_template`) set via `mlflow_config()`. Propagates to child tasks that share or nest under the parent's run. Cleared when a task creates an independent run (`run_mode="new"`). For nested runs (`run_mode="nested"`), the parent link is kept and the link name is automatically set to "MLflow (parent)".
## Parameters
```python
class Mlflow(
name: str,
link: str,
_decorator_run_mode: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `link` | `str` | |
| `_decorator_run_mode` | `str` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > MLflow > Packages > flyteplugins.mlflow > Mlflow > Methods > get_link()** | Returns a task log link given the action. |
### get_link()
```python
def get_link(
run_name: str,
project: str,
domain: str,
context: dict[str, str],
parent_action_name: str,
action_name: str,
pod_name: str,
kwargs,
) -> str
```
Returns a task log link given the action.
Link can have template variables that are replaced by the backend.
| Parameter | Type | Description |
|-|-|-|
| `run_name` | `str` | The name of the run. |
| `project` | `str` | The project name. |
| `domain` | `str` | The domain name. |
| `context` | `dict[str, str]` | Additional context for generating the link. |
| `parent_action_name` | `str` | The name of the parent action. |
| `action_name` | `str` | The name of the action. |
| `pod_name` | `str` | The name of the pod. |
| `kwargs` | `**kwargs` | Additional keyword arguments. |
**Returns:** The generated link.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/openai ===
# OpenAI
## Subpages
- [flyteplugins.openai.agents](flyteplugins.openai.agents/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars ===
# Polars
## Subpages
- **Integrations > Polars > Classes**
- **Integrations > Polars > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsDecodingHandler** | |
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsLazyFrameDecodingHandler** | |
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsLazyFrameToParquetEncodingHandler** | |
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsToParquetEncodingHandler** | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer** | |
## Subpages
- **Integrations > Polars > Packages > flyteplugins.polars.df_transformer**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars/packages/flyteplugins.polars.df_transformer ===
# flyteplugins.polars.df_transformer
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsDecodingHandler** | |
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsLazyFrameDecodingHandler** | |
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsLazyFrameToParquetEncodingHandler** | |
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsToParquetEncodingHandler** | |
### Methods
| Method | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > Methods > get_polars_storage_options()** | Get storage options in a format compatible with Polars. |
### Variables
| Property | Type | Description |
|-|-|-|
| `PARQUET` | `str` | |
## Methods
#### get_polars_storage_options()
```python
def get_polars_storage_options(
protocol: typing.Optional[str],
anonymous: bool,
) -> typing.Dict[str, str]
```
Get storage options in a format compatible with Polars.
Polars requires storage_options to be a flat dict with string keys and values,
unlike fsspec which accepts nested dicts and complex objects.
| Parameter | Type | Description |
|-|-|-|
| `protocol` | `typing.Optional[str]` | |
| `anonymous` | `bool` | |
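The flattening this function performs can be illustrated with a generic helper. The dotted-key joining scheme below is an assumption for illustration; the plugin may map nested fsspec options differently:

```python
def flatten_storage_options(options: dict, prefix: str = "") -> dict:
    """Flatten a nested fsspec-style options dict into the flat
    str -> str mapping that Polars' storage_options requires."""
    flat = {}
    for key, value in options.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_storage_options(value, name))
        else:
            flat[name] = str(value)  # Polars wants string values
    return flat

opts = flatten_storage_options(
    {"key": "abc", "client_kwargs": {"region_name": "us-east-1"}}
)
```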
## Subpages
- **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsDecodingHandler**
- **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsLazyFrameDecodingHandler**
- **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsLazyFrameToParquetEncodingHandler**
- **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsToParquetEncodingHandler**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars/packages/flyteplugins.polars.df_transformer/parquettopolarsdecodinghandler ===
# ParquetToPolarsDecodingHandler
**Package:** `flyteplugins.polars.df_transformer`
## Parameters
```python
def ParquetToPolarsDecodingHandler()
```
Extend this abstract class, implement the decode function, and register your concrete class with the
DataFrameTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the decoder interface, meaning it is used when there is a Flyte Literal value,
and we have to get a Python value out of it. For the other way, see the DataFrameEncoder
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsDecodingHandler > Methods > decode()** | Called by the dataset transformer engine to translate a Flyte Literal value into a Python instance. |
### decode()
```python
def decode(
flyte_value: flyteidl2.core.literals_pb2.StructuredDataset,
current_task_metadata: flyteidl2.core.literals_pb2.StructuredDatasetMetadata,
) -> pl.DataFrame
```
This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal
value into a Python instance.
| Parameter | Type | Description |
|-|-|-|
| `flyte_value` | `flyteidl2.core.literals_pb2.StructuredDataset` | This will be a Flyte IDL DataFrame Literal - do not confuse this with the DataFrame class defined also in this module. |
| `current_task_metadata` | `flyteidl2.core.literals_pb2.StructuredDatasetMetadata` | Metadata object containing the type (and columns if any) for the currently executing task. This type may have more or less information than the type information bundled inside the incoming flyte_value. |
**Returns:** This function can either return an instance of the dataframe that this decoder handles, or an iterator of those dataframes.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars/packages/flyteplugins.polars.df_transformer/parquettopolarslazyframedecodinghandler ===
# ParquetToPolarsLazyFrameDecodingHandler
**Package:** `flyteplugins.polars.df_transformer`
## Parameters
```python
def ParquetToPolarsLazyFrameDecodingHandler()
```
Extend this abstract class, implement the decode function, and register your concrete class with the
DataFrameTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the decoder interface, meaning it is used when there is a Flyte Literal value,
and we have to get a Python value out of it. For the other way, see the DataFrameEncoder
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > ParquetToPolarsLazyFrameDecodingHandler > Methods > decode()** | Called by the dataset transformer engine to translate a Flyte Literal value into a Python instance. |
### decode()
```python
def decode(
flyte_value: flyteidl2.core.literals_pb2.StructuredDataset,
current_task_metadata: flyteidl2.core.literals_pb2.StructuredDatasetMetadata,
) -> pl.LazyFrame
```
This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal
value into a Python instance.
| Parameter | Type | Description |
|-|-|-|
| `flyte_value` | `flyteidl2.core.literals_pb2.StructuredDataset` | This will be a Flyte IDL DataFrame Literal - do not confuse this with the DataFrame class defined also in this module. |
| `current_task_metadata` | `flyteidl2.core.literals_pb2.StructuredDatasetMetadata` | Metadata object containing the type (and columns if any) for the currently executing task. This type may have more or less information than the type information bundled inside the incoming flyte_value. |
**Returns:** This function can either return an instance of the dataframe that this decoder handles, or an iterator of those dataframes.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars/packages/flyteplugins.polars.df_transformer/polarslazyframetoparquetencodinghandler ===
# PolarsLazyFrameToParquetEncodingHandler
**Package:** `flyteplugins.polars.df_transformer`
## Parameters
```python
def PolarsLazyFrameToParquetEncodingHandler()
```
Extend this abstract class, implement the encode function, and register your concrete class with the
DataFrameTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the encoding interface, meaning it is used when there is a Python value that the
flytekit type engine is trying to convert into a Flyte Literal. For the other way, see
the DataFrameDecoder.
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsLazyFrameToParquetEncodingHandler > Methods > encode()** | Called by the dataset transformer engine to translate a Python dataframe value into a Flyte Literal. |
### encode()
```python
def encode(
dataframe: flyte.io._dataframe.dataframe.DataFrame,
structured_dataset_type: flyteidl2.core.types_pb2.StructuredDatasetType,
) -> flyteidl2.core.literals_pb2.StructuredDataset
```
Even if the user code returns a plain dataframe instance, the dataset transformer engine will wrap the
incoming dataframe with defaults set for that dataframe type. This simplifies this function's interface,
since much of the data that could be specified by the user is carried by the DataFrame wrapper class
used as input to this function (the user-facing Python class).
This function needs to return the IDL DataFrame.
| Parameter | Type | Description |
|-|-|-|
| `dataframe` | `flyte.io._dataframe.dataframe.DataFrame` | This is a DataFrame wrapper object. See more info above. |
| `structured_dataset_type` | `flyteidl2.core.types_pb2.StructuredDatasetType` | This is the DataFrameType, as found in the LiteralType of the interface of the task that invoked this encoding call. It is passed along to encoders so that authors of encoders can include it in the returned literals.DataFrame. See the IDL for more information on why this literal in particular carries the type information along with it. If the encoder doesn't supply it, it will also be filled in after the encoder runs by the transformer engine. |
**Returns:** This function should return a DataFrame literal object. Do not confuse this with the user-facing DataFrame wrapper class.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/polars/packages/flyteplugins.polars.df_transformer/polarstoparquetencodinghandler ===
# PolarsToParquetEncodingHandler
**Package:** `flyteplugins.polars.df_transformer`
## Parameters
```python
def PolarsToParquetEncodingHandler()
```
Extend this abstract class, implement the encode function, and register your concrete class with the
DataFrameTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the encoding interface, meaning it is used when there is a Python value that the
flytekit type engine is trying to convert into a Flyte Literal. For the other way, see
the DataFrameDecoder.
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Polars > Packages > flyteplugins.polars.df_transformer > PolarsToParquetEncodingHandler > Methods > encode()** | Called by the dataset transformer engine to translate a Python dataframe value into a Flyte Literal. |
### encode()
```python
def encode(
dataframe: flyte.io._dataframe.dataframe.DataFrame,
structured_dataset_type: flyteidl2.core.types_pb2.StructuredDatasetType,
) -> flyteidl2.core.literals_pb2.StructuredDataset
```
Even if the user code returns a plain dataframe instance, the dataset transformer engine will wrap the
incoming dataframe with defaults set for that dataframe type. This simplifies this function's interface,
since much of the data that could be specified by the user is carried by the DataFrame wrapper class
used as input to this function (the user-facing Python class).
This function needs to return the IDL DataFrame.
| Parameter | Type | Description |
|-|-|-|
| `dataframe` | `flyte.io._dataframe.dataframe.DataFrame` | This is a DataFrame wrapper object. See more info above. |
| `structured_dataset_type` | `flyteidl2.core.types_pb2.StructuredDatasetType` | This is the DataFrameType, as found in the LiteralType of the interface of the task that invoked this encoding call. It is passed along to encoders so that authors of encoders can include it in the returned literals.DataFrame. See the IDL for more information on why this literal in particular carries the type information along with it. If the encoder doesn't supply it, it will also be filled in after the encoder runs by the transformer engine. |
**Returns:** This function should return a DataFrame literal object. Do not confuse this with the user-facing DataFrame wrapper class.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/pytorch ===
# PyTorch
## Subpages
- **Integrations > PyTorch > Classes**
- **Integrations > PyTorch > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/pytorch/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > PyTorch > Packages > flyteplugins.pytorch > Elastic** | Elastic defines the configuration for running a PyTorch elastic job using torch.distributed. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/pytorch/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > PyTorch > Packages > flyteplugins.pytorch** | |
## Subpages
- **Integrations > PyTorch > Packages > flyteplugins.pytorch**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/pytorch/packages/flyteplugins.pytorch ===
# flyteplugins.pytorch
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > PyTorch > Packages > flyteplugins.pytorch > Elastic** | Elastic defines the configuration for running a PyTorch elastic job using torch.distributed. |
## Subpages
- **Integrations > PyTorch > Packages > flyteplugins.pytorch > Elastic**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/pytorch/packages/flyteplugins.pytorch/elastic ===
# Elastic
**Package:** `flyteplugins.pytorch`
Elastic defines the configuration for running a PyTorch elastic job using torch.distributed.
When a worker fails (e.g. CUDA OOM), the elastic agent detects the failure and
restarts all workers as a group. Each restart cycle has a cost determined by the
NCCL timeout settings below. The total worst-case time before the job fails is
`(max_restarts + 1) * (nccl_collective_timeout_sec + nccl_heartbeat_timeout_sec)`.
For example, with the defaults (max_restarts=3, collective=600s, heartbeat=300s):
4 * 900s = 60 min. With aggressive settings (max_restarts=0, collective=60s,
heartbeat=60s): 1 * 120s = 2 min.
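The worst-case bound is simple arithmetic, which makes it easy to check a candidate configuration before launching a job:

```python
def worst_case_failure_seconds(max_restarts: int,
                               collective_timeout_sec: int,
                               heartbeat_timeout_sec: int) -> int:
    # (max_restarts + 1) full cycles, each costing one collective timeout
    # plus one heartbeat timeout before the worker group is killed.
    return (max_restarts + 1) * (collective_timeout_sec + heartbeat_timeout_sec)

default = worst_case_failure_seconds(3, 600, 300)   # 3600 s = 60 min
aggressive = worst_case_failure_seconds(0, 60, 60)  # 120 s = 2 min
```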
## Parameters
```python
class Elastic(
nnodes: typing.Union[int, str],
nproc_per_node: int,
rdzv_backend: typing.Literal['c10d', 'etcd', 'etcd-v2'],
run_policy: typing.Optional[flyteplugins.pytorch.task.RunPolicy],
monitor_interval: int,
max_restarts: int,
rdzv_configs: typing.Dict[str, typing.Any],
nccl_heartbeat_timeout_sec: typing.Optional[int],
nccl_async_error_handling: bool,
nccl_collective_timeout_sec: typing.Optional[int],
nccl_enable_monitoring: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `nnodes` | `typing.Union[int, str]` | Number of nodes to use. Can be a fixed int or a range string (e.g., "2:4" for elastic training). |
| `nproc_per_node` | `int` | Number of processes to launch per node. |
| `rdzv_backend` | `typing.Literal['c10d', 'etcd', 'etcd-v2']` | Rendezvous backend to use. Typically "c10d". Defaults to "c10d". |
| `run_policy` | `typing.Optional[flyteplugins.pytorch.task.RunPolicy]` | Run policy applied to the job execution. Defaults to None. |
| `monitor_interval` | `int` | Interval (in seconds) the elastic agent polls worker process health. Once a worker process exits, detection takes at most this long. Defaults to 3. |
| `max_restarts` | `int` | Maximum number of worker group restarts before the elastic agent gives up and raises `ChildFailedError`. Each restart kills all workers and relaunches the entire group. If the failure is deterministic (e.g. model too large for GPU memory), restarts just repeat the same failure β set to 0 to fail immediately. Use higher values for transient failures (e.g. spot instance preemption, occasional OOM from variable batch sizes). Defaults to 3. |
| `rdzv_configs` | `typing.Dict[str, typing.Any]` | Rendezvous configuration key-value pairs. Defaults to {"timeout": 900, "join_timeout": 900}. |
| `nccl_heartbeat_timeout_sec` | `typing.Optional[int]` | Timeout in seconds for the NCCL heartbeat monitor thread. After the collective timeout fires and the NCCL watchdog aborts the communicator, the heartbeat monitor waits this long before sending SIGABRT to kill the worker process. This is the second phase of failure detection: it converts a stuck NCCL abort into a hard process kill. Defaults to 300 (5 min) instead of PyTorch's 1800s (30 min). Set to None to use PyTorch default. |
| `nccl_async_error_handling` | `bool` | When True, sets TORCH_NCCL_ASYNC_ERROR_HANDLING=1 so that NCCL aborts stuck collectives asynchronously instead of blocking indefinitely. This causes the worker process to crash-exit on a stuck collective, which the elastic agent detects within `monitor_interval` seconds (~3s by default), much faster than waiting for the heartbeat timeout. Defaults to False (PyTorch default behavior). |
| `nccl_collective_timeout_sec` | `typing.Optional[int]` | Timeout in seconds for individual NCCL collective operations (e.g. all-reduce inside loss.backward()). This is the timeout passed to `torch.distributed.init_process_group`. When a worker desyncs (e.g. skips a collective after OOM), surviving workers block in the collective for this long before the NCCL watchdog fires. This is the first phase of failure detection. PyTorch default is 600s (10 min). Set to None to use PyTorch default. |
| `nccl_enable_monitoring` | `bool` | When True, sets TORCH_NCCL_ENABLE_MONITORING=1 to activate NCCL's built-in monitoring thread. The monitoring thread checks each worker's heartbeat counter and sends SIGABRT when it stalls, which is what drives `nccl_heartbeat_timeout_sec`. Defaults to True. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/ray ===
# Ray
## Subpages
- **Integrations > Ray > Classes**
- **Integrations > Ray > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/ray/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Ray > Packages > flyteplugins.ray > HeadNodeConfig** | |
| **Integrations > Ray > Packages > flyteplugins.ray > RayJobConfig** | |
| **Integrations > Ray > Packages > flyteplugins.ray > WorkerNodeConfig** | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/ray/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Ray > Packages > flyteplugins.ray** | |
## Subpages
- **Integrations > Ray > Packages > flyteplugins.ray**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/ray/packages/flyteplugins.ray ===
# flyteplugins.ray
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Ray > Packages > flyteplugins.ray > HeadNodeConfig** | |
| **Integrations > Ray > Packages > flyteplugins.ray > RayJobConfig** | |
| **Integrations > Ray > Packages > flyteplugins.ray > WorkerNodeConfig** | |
## Subpages
- **Integrations > Ray > Packages > flyteplugins.ray > HeadNodeConfig**
- **Integrations > Ray > Packages > flyteplugins.ray > RayJobConfig**
- **Integrations > Ray > Packages > flyteplugins.ray > WorkerNodeConfig**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/ray/packages/flyteplugins.ray/headnodeconfig ===
# HeadNodeConfig
**Package:** `flyteplugins.ray`
## Parameters
```python
class HeadNodeConfig(
    ray_start_params: typing.Optional[typing.Dict[str, str]],
    pod_template: typing.Optional[flyte._pod.PodTemplate],
    requests: typing.Optional[flyte._resources.Resources],
    limits: typing.Optional[flyte._resources.Resources],
)
```
| Parameter | Type | Description |
|-|-|-|
| `ray_start_params` | `typing.Optional[typing.Dict[str, str]]` | |
| `pod_template` | `typing.Optional[flyte._pod.PodTemplate]` | |
| `requests` | `typing.Optional[flyte._resources.Resources]` | |
| `limits` | `typing.Optional[flyte._resources.Resources]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/ray/packages/flyteplugins.ray/rayjobconfig ===
# RayJobConfig
**Package:** `flyteplugins.ray`
## Parameters
```python
class RayJobConfig(
    worker_node_config: typing.List[flyteplugins.ray.task.WorkerNodeConfig],
    head_node_config: typing.Optional[flyteplugins.ray.task.HeadNodeConfig],
    enable_autoscaling: bool,
    runtime_env: typing.Optional[dict],
    address: typing.Optional[str],
    shutdown_after_job_finishes: bool,
    ttl_seconds_after_finished: typing.Optional[int],
)
```
| Parameter | Type | Description |
|-|-|-|
| `worker_node_config` | `typing.List[flyteplugins.ray.task.WorkerNodeConfig]` | |
| `head_node_config` | `typing.Optional[flyteplugins.ray.task.HeadNodeConfig]` | |
| `enable_autoscaling` | `bool` | |
| `runtime_env` | `typing.Optional[dict]` | |
| `address` | `typing.Optional[str]` | |
| `shutdown_after_job_finishes` | `bool` | |
| `ttl_seconds_after_finished` | `typing.Optional[int]` | |
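Taken together with `HeadNodeConfig` and `WorkerNodeConfig`, a cluster configuration can be sketched as follows. This is a minimal illustration, not a recommended setup: the group name, replica counts, TTL, and the `ray_start_params` value are made-up examples.

```python
from flyteplugins.ray import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

# Hypothetical cluster shape: one head node plus an autoscaling worker group.
ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(
        # Standard `ray start` flag; exposes the dashboard on all interfaces.
        ray_start_params={"dashboard-host": "0.0.0.0"},
    ),
    worker_node_config=[
        WorkerNodeConfig(
            group_name="workers",
            replicas=2,
            min_replicas=1,
            max_replicas=4,
        ),
    ],
    enable_autoscaling=True,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)
```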
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/ray/packages/flyteplugins.ray/workernodeconfig ===
# WorkerNodeConfig
**Package:** `flyteplugins.ray`
## Parameters
```python
class WorkerNodeConfig(
    group_name: str,
    replicas: int,
    min_replicas: typing.Optional[int],
    max_replicas: typing.Optional[int],
    ray_start_params: typing.Optional[typing.Dict[str, str]],
    pod_template: typing.Optional[flyte._pod.PodTemplate],
    requests: typing.Optional[flyte._resources.Resources],
    limits: typing.Optional[flyte._resources.Resources],
)
```
| Parameter | Type | Description |
|-|-|-|
| `group_name` | `str` | |
| `replicas` | `int` | |
| `min_replicas` | `typing.Optional[int]` | |
| `max_replicas` | `typing.Optional[int]` | |
| `ray_start_params` | `typing.Optional[typing.Dict[str, str]]` | |
| `pod_template` | `typing.Optional[flyte._pod.PodTemplate]` | |
| `requests` | `typing.Optional[flyte._resources.Resources]` | |
| `limits` | `typing.Optional[flyte._resources.Resources]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/sglang ===
# SGLang
## Subpages
- **Integrations > SGLang > Classes**
- **Integrations > SGLang > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/sglang/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment** | App environment backed by SGLang for serving large language models. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/sglang/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > SGLang > Packages > flyteplugins.sglang** | |
## Subpages
- **Integrations > SGLang > Packages > flyteplugins.sglang**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/sglang/packages/flyteplugins.sglang ===
# flyteplugins.sglang
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment** | App environment backed by SGLang for serving large language models. |
### Variables
| Property | Type | Description |
|-|-|-|
| `DEFAULT_SGLANG_IMAGE` | `Image` | |
## Subpages
- **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/sglang/packages/flyteplugins.sglang/sglangappenvironment ===
# SGLangAppEnvironment
**Package:** `flyteplugins.sglang`
App environment backed by SGLang for serving large language models.
This environment sets up an SGLang server with the specified model and configuration.
## Parameters
```python
class SGLangAppEnvironment(
    name: str,
    depends_on: List[Environment],
    pod_template: Optional[Union[str, PodTemplate]],
    description: Optional[str],
    secrets: Optional[SecretRequest],
    env_vars: Optional[Dict[str, str]],
    resources: Optional[Resources],
    interruptible: bool,
    args: *args,
    command: Optional[Union[List[str], str]],
    requires_auth: bool,
    scaling: Scaling,
    domain: Domain | None,
    links: List[Link],
    include: List[str],
    parameters: List[Parameter],
    cluster_pool: str,
    timeouts: Timeouts,
    image: str | Image | Literal['auto'],
    type: str,
    port: int | Port,
    extra_args: str | list[str],
    model_path: str | RunOutput,
    model_hf_path: str,
    model_id: str,
    stream_model: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the application. |
| `depends_on` | `List[Environment]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `description` | `Optional[str]` | |
| `secrets` | `Optional[SecretRequest]` | Secrets that are requested for application. |
| `env_vars` | `Optional[Dict[str, str]]` | Environment variables to set for the application. |
| `resources` | `Optional[Resources]` | |
| `interruptible` | `bool` | |
| `args` | `*args` | |
| `command` | `Optional[Union[List[str], str]]` | |
| `requires_auth` | `bool` | Whether the public URL requires authentication. |
| `scaling` | `Scaling` | Scaling configuration for the app environment. |
| `domain` | `Domain \| None` | Domain to use for the app. |
| `links` | `List[Link]` | |
| `include` | `List[str]` | |
| `parameters` | `List[Parameter]` | |
| `cluster_pool` | `str` | The target cluster_pool where the app should be deployed. |
| `timeouts` | `Timeouts` | |
| `image` | `str \| Image \| Literal['auto']` | |
| `type` | `str` | Type of app. |
| `port` | `int \| Port` | Port application listens to. Defaults to 8000 for SGLang. |
| `extra_args` | `str \| list[str]` | Extra args to pass to `python -m sglang.launch_server`. See https://docs.sglang.io/advanced_features/server_arguments.html for details. |
| `model_path` | `str \| RunOutput` | Remote path to the model (e.g., an `s3://` path). |
| `model_hf_path` | `str` | Hugging Face path to model (e.g., Qwen/Qwen3-0.6B). |
| `model_id` | `str` | Model id that is exposed by SGLang. |
| `stream_model` | `bool` | Set to True to stream model from blob store to the GPU directly. If False, the model will be downloaded to the local file system first and then loaded into the GPU. |
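As a sketch of how these parameters fit together, the following constructs an environment that serves a Hugging Face model. The app name, `model_id`, and `extra_args` values are illustrative assumptions, not defaults (`--context-length` is a standard `sglang.launch_server` argument).

```python
from flyteplugins.sglang import SGLangAppEnvironment

# Hypothetical app serving a small Hugging Face model with SGLang.
qwen_app = SGLangAppEnvironment(
    name="qwen-serving",
    model_hf_path="Qwen/Qwen3-0.6B",   # pulled from the Hugging Face Hub
    model_id="qwen3-0.6b",             # id exposed by the SGLang server
    stream_model=True,                 # stream weights to the GPU directly
    extra_args=["--context-length", "8192"],
)
```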
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > add_dependency()** | Add a dependency to the environment. |
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > clone_with()** | |
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > container_args()** | Return the container arguments for SGLang. |
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > container_cmd()** | |
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > get_port()** | |
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > on_shutdown()** | Decorator to define the shutdown function for the app environment. |
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > on_startup()** | Decorator to define the startup function for the app environment. |
| **Integrations > SGLang > Packages > flyteplugins.sglang > SGLangAppEnvironment > Methods > server()** | Decorator to define the server function for the app environment. |
### add_dependency()
```python
def add_dependency(
    env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
    name: str,
    image: Optional[Union[str, Image, Literal['auto']]],
    resources: Optional[Resources],
    env_vars: Optional[dict[str, str]],
    secrets: Optional[SecretRequest],
    depends_on: Optional[list[Environment]],
    description: Optional[str],
    interruptible: Optional[bool],
    kwargs: **kwargs,
) -> SGLangAppEnvironment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[list[Environment]]` | |
| `description` | `Optional[str]` | |
| `interruptible` | `Optional[bool]` | |
| `kwargs` | `**kwargs` | |
### container_args()
```python
def container_args(
    serialization_context: SerializationContext,
) -> list[str]
```
Return the container arguments for SGLang.
| Parameter | Type | Description |
|-|-|-|
| `serialization_context` | `SerializationContext` | |
### container_cmd()
```python
def container_cmd(
    serialize_context: SerializationContext,
    parameter_overrides: list[Parameter] | None,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
| `parameter_overrides` | `list[Parameter] \| None` | |
### get_port()
```python
def get_port()
```
### on_shutdown()
```python
def on_shutdown(
    fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the shutdown function for the app environment.
This function is called after the server function is called.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### on_startup()
```python
def on_startup(
    fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the startup function for the app environment.
This function is called before the server function is called.
The decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### server()
```python
def server(
    fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the server function for the app environment.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/snowflake ===
# Snowflake
## Subpages
- **Integrations > Snowflake > Classes**
- **Integrations > Snowflake > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/snowflake/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake** | |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConfig** | Configure a Snowflake Task using a `SnowflakeConfig` object. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector** | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/snowflake/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Snowflake > Packages > flyteplugins.snowflake** | Run parameterized SQL queries against Snowflake from Flyte tasks. |
## Subpages
- **Integrations > Snowflake > Packages > flyteplugins.snowflake**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/snowflake/packages/flyteplugins.snowflake ===
# flyteplugins.snowflake
Key features:
- Parameterized SQL queries with typed inputs
- Key-pair and password-based authentication
- Returns query results as DataFrames
- Automatic links to the Snowflake query dashboard in the Flyte UI
- Query cancellation on task abort
Basic usage example:
```python
import flyte
from flyte.io import DataFrame
from flyteplugins.snowflake import Snowflake, SnowflakeConfig
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
)

count_users = Snowflake(
    name="count_users",
    query_template="SELECT COUNT(*) FROM users",
    plugin_config=config,
    output_dataframe_type=DataFrame,
)

flyte.TaskEnvironment.from_task("snowflake_env", count_users)

if __name__ == "__main__":
    flyte.init_from_config()
    # Run locally (connector runs in-process, requires credentials and packages locally)
    run = flyte.with_runcontext(mode="local").run(count_users)
    # Run remotely (connector runs on the control plane)
    run = flyte.with_runcontext(mode="remote").run(count_users)
    print(run.url)
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake** | |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConfig** | Configure a Snowflake Task using a `SnowflakeConfig` object. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector** | |
## Subpages
- **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake**
- **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConfig**
- **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/snowflake/packages/flyteplugins.snowflake/snowflake ===
# Snowflake
**Package:** `flyteplugins.snowflake`
## Parameters
```python
class Snowflake(
    name: str,
    query_template: str,
    plugin_config: flyteplugins.snowflake.task.SnowflakeConfig,
    inputs: typing.Optional[typing.Dict[str, typing.Type]],
    output_dataframe_type: typing.Optional[typing.Type],
    secret_group: typing.Optional[str],
    snowflake_private_key: typing.Optional[str],
    snowflake_private_key_passphrase: typing.Optional[str],
    batch: bool,
    kwargs,
)
```
Task to run parameterized SQL queries against Snowflake.
Note: For password authentication or other auth methods, pass them via `connection_kwargs`.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of this task. |
| `query_template` | `str` | The actual query to run. This can be parameterized using Python's printf-style string formatting with named parameters (e.g. %(param_name)s). |
| `plugin_config` | `flyteplugins.snowflake.task.SnowflakeConfig` | `SnowflakeConfig` object containing connection metadata. |
| `inputs` | `typing.Optional[typing.Dict[str, typing.Type]]` | Name and type of inputs specified as a dictionary. |
| `output_dataframe_type` | `typing.Optional[typing.Type]` | If some data is produced by this query, then you can specify the output dataframe type. |
| `secret_group` | `typing.Optional[str]` | Optional group for secrets in the secret store. The environment variable name is auto-generated from `{secret_group}_{key}`, uppercased with hyphens replaced by underscores. If omitted, the key alone is used. |
| `snowflake_private_key` | `typing.Optional[str]` | The secret key for the Snowflake private key (key-pair auth). |
| `snowflake_private_key_passphrase` | `typing.Optional[str]` | The secret key for the private key passphrase (if encrypted). |
| `batch` | `bool` | When True, list inputs are expanded into a multi-row VALUES clause. The query_template should contain a single `VALUES (%(col)s, ...)` placeholder and each input should be a list of equal length. |
| `kwargs` | `**kwargs` | |
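The printf-style placeholders in `query_template` pair with the `inputs` mapping. The snippet below only illustrates the placeholder syntax with plain Python `%`-formatting; the actual parameter binding is performed by the connector at run time, and the table and column names here are made up.

```python
# Illustrative template with a named placeholder, as used in query_template.
template = "SELECT COUNT(*) FROM users WHERE signup_date > %(cutoff)s"

# On the task, this would pair with inputs={"cutoff": str}.
# Plain %-formatting shows what the substitution looks like:
rendered = template % {"cutoff": "'2024-01-01'"}
print(rendered)  # SELECT COUNT(*) FROM users WHERE signup_date > '2024-01-01'
```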
## Properties
| Property | Type | Description |
|-|-|-|
| `native_interface` | `None` | |
| `source_file` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > aio()** | The aio function allows executing "sync" tasks, in an async context. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > config()** | Returns additional configuration for the task. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > container_args()** | Returns the container args for the task. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > custom_config()** | Returns additional configuration for the task. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > data_loading_config()** | This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > execute()** | |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > forward()** | Think of this as a local execute method for your task. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > override()** | Override various parameters of the task template. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > post()** | The post-execute function called after the task is executed. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > pre()** | The pre-execute function called before the task is executed. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > Snowflake > Methods > sql()** | Returns the SQL for the task. |
### aio()
```python
def aio(
    args: *args,
    kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
The aio function allows executing "sync" tasks in an async context. This helps with migrating v1-defined
sync tasks to be used within an asyncio parent task.
This function will also re-raise exceptions from the underlying task.
Example:
```python
@env.task
def my_legacy_task(x: int) -> int:
    return x

@env.task
async def my_new_parent_task(n: int) -> List[int]:
    collect = []
    for x in range(n):
        collect.append(my_legacy_task.aio(x))
    return await asyncio.gather(*collect)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### config()
```python
def config(
    sctx: SerializationContext,
) -> Dict[str, str]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### container_args()
```python
def container_args(
    sctx: SerializationContext,
) -> List[str]
```
Returns the container args for the task. These arguments configure the task execution environment at
runtime and are usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### custom_config()
```python
def custom_config(
    sctx: flyte.models.SerializationContext,
) -> typing.Optional[typing.Dict[str, typing.Any]]
```
Returns additional configuration for the task. This is a set of key-value pairs that can be used to
configure the task execution environment at runtime. This is usually used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `flyte.models.SerializationContext` | |
### data_loading_config()
```python
def data_loading_config(
    sctx: SerializationContext,
) -> DataLoadingConfig
```
This configuration allows executing raw containers in Flyte using the Flyte CoPilot system.
Flyte CoPilot eliminates the need for an SDK inside the container: any inputs required by the user's
container are side-loaded into the input_path, and any outputs generated by the user container within
the output_path are automatically uploaded.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `SerializationContext` | |
### execute()
```python
def execute(
    kwargs,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `kwargs` | `**kwargs` | |
### forward()
```python
def forward(
    args: *args,
    kwargs: **kwargs,
) -> Coroutine[Any, Any, R] | R
```
Think of this as a local execute method for your task. This function will be invoked by the `__call__`
method when not in a Flyte task execution context.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### override()
```python
def override(
    short_name: Optional[str],
    resources: Optional[Resources],
    cache: Optional[CacheRequest],
    retries: Union[int, RetryStrategy],
    timeout: Optional[TimeoutType],
    reusable: Union[ReusePolicy, Literal['off'], None],
    env_vars: Optional[Dict[str, str]],
    secrets: Optional[SecretRequest],
    max_inline_io_bytes: int | None,
    pod_template: Optional[Union[str, PodTemplate]],
    queue: Optional[str],
    interruptible: Optional[bool],
    links: Tuple[Link, ...],
    kwargs: **kwargs,
) -> TaskTemplate
```
Override various parameters of the task template. This allows for dynamic configuration of the task
when it is called, such as changing the image, resources, cache policy, etc.
| Parameter | Type | Description |
|-|-|-|
| `short_name` | `Optional[str]` | Optional override for the short name of the task. |
| `resources` | `Optional[Resources]` | Optional override for the resources to use for the task. |
| `cache` | `Optional[CacheRequest]` | Optional override for the cache policy for the task. |
| `retries` | `Union[int, RetryStrategy]` | Optional override for the number of retries for the task. |
| `timeout` | `Optional[TimeoutType]` | Optional override for the timeout for the task. |
| `reusable` | `Union[ReusePolicy, Literal['off'], None]` | Optional override for the reusability policy for the task. |
| `env_vars` | `Optional[Dict[str, str]]` | Optional override for the environment variables to set for the task. |
| `secrets` | `Optional[SecretRequest]` | Optional override for the secrets that will be injected into the task at runtime. |
| `max_inline_io_bytes` | `int \| None` | Optional override for the maximum allowed size (in bytes) for all inputs and outputs passed directly to the task. |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | Optional override for the pod template to use for the task. |
| `queue` | `Optional[str]` | Optional override for the queue to use for the task. |
| `interruptible` | `Optional[bool]` | Optional override for the interruptible policy for the task. |
| `links` | `Tuple[Link, ...]` | Optional override for the Links associated with the task. |
| `kwargs` | `**kwargs` | Additional keyword arguments for further overrides. Some fields like name, image, docs, and interface cannot be overridden. |
**Returns:** A new TaskTemplate instance with the overridden parameters.
### post()
```python
def post(
    return_vals: Any,
) -> Any
```
This is the post-execute function that will be called after the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `return_vals` | `Any` | |
### pre()
```python
def pre(
    args,
    kwargs,
) -> Dict[str, Any]
```
This is the pre-execute function that will be called before the task is executed.
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwargs` | `**kwargs` | |
### sql()
```python
def sql(
    sctx: flyte.models.SerializationContext,
) -> typing.Optional[str]
```
Returns the SQL statement for the task, rendered in the given serialization context. This is usually
used by plugins.
| Parameter | Type | Description |
|-|-|-|
| `sctx` | `flyte.models.SerializationContext` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/snowflake/packages/flyteplugins.snowflake/snowflakeconfig ===
# SnowflakeConfig
**Package:** `flyteplugins.snowflake`
Configure a Snowflake Task using a `SnowflakeConfig` object.
Additional connection parameters (role, authenticator, session_parameters, etc.) can be passed
via connection_kwargs.
See: https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api
## Parameters
```python
class SnowflakeConfig(
    account: str,
    database: str,
    schema: str,
    warehouse: str,
    user: str,
    connection_kwargs: typing.Optional[typing.Dict[str, typing.Any]],
)
```
| Parameter | Type | Description |
|-|-|-|
| `account` | `str` | The Snowflake account identifier. |
| `database` | `str` | The Snowflake database name. |
| `schema` | `str` | The Snowflake schema name. |
| `warehouse` | `str` | The Snowflake warehouse name. |
| `user` | `str` | The Snowflake user name. |
| `connection_kwargs` | `typing.Optional[typing.Dict[str, typing.Any]]` | Optional dictionary of additional Snowflake connection parameters. |
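For instance, a config that also sets a role and authenticator via `connection_kwargs`. The account, user, and role values are illustrative; the `connection_kwargs` entries are passed through to the Snowflake Python connector's connection parameters.

```python
from flyteplugins.snowflake import SnowflakeConfig

# Hypothetical connection with extra parameters beyond the core fields.
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={"role": "ANALYST", "authenticator": "snowflake"},
)
```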
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/snowflake/packages/flyteplugins.snowflake/snowflakeconnector ===
# SnowflakeConnector
**Package:** `flyteplugins.snowflake`
## Methods
| Method | Description |
|-|-|
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector > Methods > create()** | Submit a query to Snowflake asynchronously. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector > Methods > delete()** | Cancel a running Snowflake query. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector > Methods > get()** | Poll the status of a Snowflake query. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector > Methods > get_logs()** | Return the logs for the task. |
| **Integrations > Snowflake > Packages > flyteplugins.snowflake > SnowflakeConnector > Methods > get_metrics()** | Return the metrics for the task. |
### create()
```python
def create(
    task_template: flyteidl2.core.tasks_pb2.TaskTemplate,
    inputs: typing.Optional[typing.Dict[str, typing.Any]],
    snowflake_private_key: typing.Optional[str],
    snowflake_private_key_passphrase: typing.Optional[str],
    kwargs,
) -> flyteplugins.snowflake.connector.SnowflakeJobMetadata
```
Submit a query to Snowflake asynchronously.
| Parameter | Type | Description |
|-|-|-|
| `task_template` | `flyteidl2.core.tasks_pb2.TaskTemplate` | The Flyte task template containing the SQL query and configuration. |
| `inputs` | `typing.Optional[typing.Dict[str, typing.Any]]` | Optional dictionary of input parameters for parameterized queries. |
| `snowflake_private_key` | `typing.Optional[str]` | The private key content set as a Flyte secret. |
| `snowflake_private_key_passphrase` | `typing.Optional[str]` | The passphrase for the private key set as a Flyte secret, if any. |
| `kwargs` | `**kwargs` | |
**Returns:** A SnowflakeJobMetadata object containing the query ID and link to the query dashboard.
### delete()
```python
def delete(
    resource_meta: flyteplugins.snowflake.connector.SnowflakeJobMetadata,
    snowflake_private_key: typing.Optional[str],
    snowflake_private_key_passphrase: typing.Optional[str],
    kwargs,
)
```
Cancel a running Snowflake query.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyteplugins.snowflake.connector.SnowflakeJobMetadata` | The SnowflakeJobMetadata containing the query ID. |
| `snowflake_private_key` | `typing.Optional[str]` | The private key content set as a Flyte secret. |
| `snowflake_private_key_passphrase` | `typing.Optional[str]` | The passphrase for the private key set as a Flyte secret, if any. |
| `kwargs` | `**kwargs` | |
### get()
```python
def get(
    resource_meta: flyteplugins.snowflake.connector.SnowflakeJobMetadata,
    snowflake_private_key: typing.Optional[str],
    snowflake_private_key_passphrase: typing.Optional[str],
    kwargs,
) -> flyte.connectors._connector.Resource
```
Poll the status of a Snowflake query.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyteplugins.snowflake.connector.SnowflakeJobMetadata` | The SnowflakeJobMetadata containing the query ID. |
| `snowflake_private_key` | `typing.Optional[str]` | The private key content set as a Flyte secret. |
| `snowflake_private_key_passphrase` | `typing.Optional[str]` | The passphrase for the private key set as a Flyte secret, if any. |
| `kwargs` | `**kwargs` | |
**Returns:** A Resource object containing the query results and a link to the query dashboard.
### get_logs()
```python
def get_logs(
    resource_meta: flyte.connectors._connector.ResourceMeta,
    kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskLogsResponse
```
Return the logs for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
### get_metrics()
```python
def get_metrics(
    resource_meta: flyte.connectors._connector.ResourceMeta,
    kwargs,
) -> flyteidl2.connector.connector_pb2.GetTaskMetricsResponse
```
Return the metrics for the task.
| Parameter | Type | Description |
|-|-|-|
| `resource_meta` | `flyte.connectors._connector.ResourceMeta` | |
| `kwargs` | `**kwargs` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/spark ===
# Spark
## Subpages
- **Integrations > Spark > Classes**
- **Integrations > Spark > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/spark/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Spark > Packages > flyteplugins.spark > ParquetToSparkDecoder** | |
| **Integrations > Spark > Packages > flyteplugins.spark > Spark** | Use this to configure a SparkContext for your task. |
| **Integrations > Spark > Packages > flyteplugins.spark > SparkToParquetEncoder** | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/spark/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Spark > Packages > flyteplugins.spark** | |
## Subpages
- **Integrations > Spark > Packages > flyteplugins.spark**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/spark/packages/flyteplugins.spark ===
# flyteplugins.spark
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Spark > Packages > flyteplugins.spark > ParquetToSparkDecoder** | |
| **Integrations > Spark > Packages > flyteplugins.spark > Spark** | Use this to configure a SparkContext for your task. |
| **Integrations > Spark > Packages > flyteplugins.spark > SparkToParquetEncoder** | |
## Subpages
- **Integrations > Spark > Packages > flyteplugins.spark > ParquetToSparkDecoder**
- **Integrations > Spark > Packages > flyteplugins.spark > Spark**
- **Integrations > Spark > Packages > flyteplugins.spark > SparkToParquetEncoder**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/spark/packages/flyteplugins.spark/parquettosparkdecoder ===
# ParquetToSparkDecoder
**Package:** `flyteplugins.spark`
## Parameters
```python
class ParquetToSparkDecoder()
```
Extend this abstract class, implement the decode function, and register your concrete class with the
DataFrameTransformerEngine class in order for the core flytekit type engine to handle
dataframe libraries. This is the decoder interface, meaning it is used when there is a Flyte Literal value
and we have to get a Python value out of it. For the other way, see the DataFrameEncoder.
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Spark > Packages > flyteplugins.spark > ParquetToSparkDecoder > Methods > decode()** | This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal. |
### decode()
```python
def decode(
    flyte_value: flyteidl2.core.literals_pb2.StructuredDataset,
    current_task_metadata: flyteidl2.core.literals_pb2.StructuredDatasetMetadata,
) -> pyspark.sql.dataframe.DataFrame
```
This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte
Literal value into a Python instance.
| Parameter | Type | Description |
|-|-|-|
| `flyte_value` | `flyteidl2.core.literals_pb2.StructuredDataset` | This will be a Flyte IDL DataFrame Literal - do not confuse this with the DataFrame class defined also in this module. |
| `current_task_metadata` | `flyteidl2.core.literals_pb2.StructuredDatasetMetadata` | Metadata object containing the type (and columns if any) for the currently executing task. This type may have more or less information than the type information bundled inside the incoming flyte_value. |
**Returns:** This function can either return an instance of the dataframe that this decoder handles, or an iterator of those dataframes.
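The extend-implement-register pattern described above can be sketched with a minimal local stand-in. The class names, method bodies, and registry below are hypothetical illustrations of the call shape, not the real flytekit type-engine interfaces:

```python
from abc import ABC, abstractmethod

# Hypothetical stand-in for the decoder interface; the real abstract class
# and DataFrameTransformerEngine registration live in flytekit's type engine.
class DataFrameDecoder(ABC):
    @abstractmethod
    def decode(self, flyte_value, current_task_metadata):
        """Translate a Flyte Literal value into a Python dataframe instance."""

class MyParquetDecoder(DataFrameDecoder):
    def decode(self, flyte_value, current_task_metadata):
        # A real decoder would read the parquet data referenced by
        # flyte_value and return e.g. a pyspark DataFrame; here we just
        # echo the URI to illustrate the call shape.
        return {"source_uri": flyte_value}

# A plain dict standing in for engine registration.
DECODERS = {}
DECODERS["parquet"] = MyParquetDecoder()

df = DECODERS["parquet"].decode("s3://bucket/data.parquet", None)
print(df)  # {'source_uri': 's3://bucket/data.parquet'}
```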
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/spark/packages/flyteplugins.spark/spark ===
# Spark
**Package:** `flyteplugins.spark`
Use this to configure a SparkContext for your task. Tasks marked with this will automatically execute
natively on Kubernetes as a distributed Spark execution.
## Parameters
```python
class Spark(
spark_conf: typing.Optional[typing.Dict[str, str]],
hadoop_conf: typing.Optional[typing.Dict[str, str]],
executor_path: typing.Optional[str],
applications_path: typing.Optional[str],
driver_pod: typing.Optional[flyte._pod.PodTemplate],
executor_pod: typing.Optional[flyte._pod.PodTemplate],
)
```
| Parameter | Type | Description |
|-|-|-|
| `spark_conf` | `typing.Optional[typing.Dict[str, str]]` | Spark configuration dictionary. |
| `hadoop_conf` | `typing.Optional[typing.Dict[str, str]]` | Hadoop configuration dictionary. |
| `executor_path` | `typing.Optional[str]` | Path to the Python binary for PySpark execution. |
| `applications_path` | `typing.Optional[str]` | Path to the main application file. |
| `driver_pod` | `typing.Optional[flyte._pod.PodTemplate]` | Pod template for the driver pod. |
| `executor_pod` | `typing.Optional[flyte._pod.PodTemplate]` | Pod template for the executor pods. |
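For illustration, here is how such a configuration might be assembled. `SparkConfig` below is a local dataclass stand-in mirroring the signature above (so the snippet runs without the plugin installed); the real class is `Spark` from `flyteplugins.spark`, and the specific property values are examples only:

```python
from dataclasses import dataclass
from typing import Dict, Optional

# Local stand-in mirroring the Spark config signature documented above.
@dataclass
class SparkConfig:
    spark_conf: Optional[Dict[str, str]] = None
    hadoop_conf: Optional[Dict[str, str]] = None
    executor_path: Optional[str] = None
    applications_path: Optional[str] = None

cfg = SparkConfig(
    spark_conf={
        "spark.executor.instances": "2",
        "spark.executor.memory": "2g",
        "spark.driver.memory": "1g",
    },
    executor_path="/usr/bin/python3",
)
print(cfg.spark_conf["spark.executor.instances"])  # 2
```

The `spark_conf` dictionary carries standard Spark properties through to the driver and executors; pod templates, when provided, control the Kubernetes pod specs beyond what Spark properties express.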
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/spark/packages/flyteplugins.spark/sparktoparquetencoder ===
# SparkToParquetEncoder
**Package:** `flyteplugins.spark`
## Parameters
```python
def SparkToParquetEncoder()
```
Extend this abstract class, implement the `encode` function, and register your concrete class with the
DataFrameTransformerEngine class so that the core flytekit type engine can handle
dataframe libraries. This is the encoding interface, meaning it is used when there is a Python value that the
flytekit type engine is trying to convert into a Flyte Literal. For the other direction, see
the DataFrameDecoder.
## Properties
| Property | Type | Description |
|-|-|-|
| `protocol` | `None` | |
| `python_type` | `None` | |
| `supported_format` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Spark > Packages > flyteplugins.spark > SparkToParquetEncoder > Methods > encode()** | Called by the dataset transformer engine to encode a wrapped dataframe into a Flyte Literal. |
### encode()
```python
def encode(
dataframe: flyte.io._dataframe.dataframe.DataFrame,
structured_dataset_type: flyteidl2.core.types_pb2.StructuredDatasetType,
) -> flyteidl2.core.literals_pb2.StructuredDataset
```
Even if the user code returns a plain dataframe instance, the dataset transformer engine will wrap the
incoming dataframe with defaults set for that dataframe type. This simplifies this function's interface,
since much of the data that could be specified by the user is carried by the DataFrame wrapper class used
as input to this function (the user-facing Python class).
This function needs to return the IDL DataFrame.
| Parameter | Type | Description |
|-|-|-|
| `dataframe` | `flyte.io._dataframe.dataframe.DataFrame` | This is a DataFrame wrapper object. See more info above. |
| `structured_dataset_type` | `flyteidl2.core.types_pb2.StructuredDatasetType` | This the DataFrameType, as found in the LiteralType of the interface of the task that invoked this encoding call. It is passed along to encoders so that authors of encoders can include it in the returned literals.DataFrame. See the IDL for more information on why this literal in particular carries the type information along with it. If the encoder doesn't supply it, it will also be filled in after the encoder runs by the transformer engine. |
**Returns:** This function should return a DataFrame literal object. Do not confuse this with the user-facing Python DataFrame wrapper class.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union ===
# Union
## Subpages
- **Integrations > Union > Classes**
- **Integrations > Union > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey** | Represents a Union API Key (OAuth Application). |
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment** | Represents role/policy assignments for an identity. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Member** | Represents a Union organization member (user or application). |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy** | Represents a Union RBAC Policy. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role** | Represents a Union RBAC Role. |
| **Integrations > Union > Packages > flyteplugins.union.remote > User** | Represents a Union user. |
| **Integrations > Union > Packages > flyteplugins.union.utils.auth > AppClientCredentials** | Application client credentials for API key. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.cli** | |
| **Integrations > Union > Packages > flyteplugins.union.internal.validate.validate.validate_pb2** | Generated protocol buffer code. |
| **Integrations > Union > Packages > flyteplugins.union.internal.validate.validate.validate_pb2_grpc** | Client and server classes corresponding to protobuf-defined services. |
| **Integrations > Union > Packages > flyteplugins.union.remote** | Union remote control plane objects. |
| **Integrations > Union > Packages > flyteplugins.union.utils.auth** | |
## Subpages
- **Integrations > Union > Packages > flyteplugins.union.cli**
- **Integrations > Union > Packages > flyteplugins.union.internal.validate.validate.validate_pb2**
- **Integrations > Union > Packages > flyteplugins.union.internal.validate.validate.validate_pb2_grpc**
- **Integrations > Union > Packages > flyteplugins.union.remote**
- **Integrations > Union > Packages > flyteplugins.union.utils.auth**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.cli ===
# flyteplugins.union.cli
## Directory
### Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.cli > Methods > edit_with_retry()** | Open an editor and retry, or save to file on failure. |
## Methods
#### edit_with_retry()
```python
def edit_with_retry(
yaml_text: str,
apply_fn,
console,
noun: str,
)
```
Open an editor and retry, or save to file on failure.
| Parameter | Type | Description |
|-|-|-|
| `yaml_text` | `str` | Initial YAML content to edit. |
| `apply_fn` | | Callable that takes the edited YAML string and applies it. Should raise on failure. |
| `console` | | Rich console for output. |
| `noun` | `str` | Name of the resource for messages (e.g. "role", "policy"). |
**Returns:** The result of apply_fn on success, or None if cancelled.
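The control flow can be sketched as a plain retry loop. This is a simplified, hypothetical model: the real function re-opens the editor between attempts and offers to save the edited text to a file on cancellation.

```python
def apply_with_retry(yaml_text, apply_fn, max_attempts=3):
    # Simplified model of the retry control flow: try apply_fn, and on
    # failure retry (the real function re-opens the editor between
    # attempts); return None once the attempts are exhausted.
    for _ in range(max_attempts):
        try:
            return apply_fn(yaml_text)
        except Exception:
            continue
    return None

attempts = []

def flaky_apply(text):
    attempts.append(text)
    if len(attempts) < 2:
        raise ValueError("invalid YAML")  # first attempt fails
    return "applied"

print(apply_with_retry("role: admin", flaky_apply))  # applied
```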
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.internal.validate.validate.validate_pb2 ===
# flyteplugins.union.internal.validate.validate.validate_pb2
Generated protocol buffer code.
## Directory
### Variables
| Property | Type | Description |
|-|-|-|
| `DISABLED_FIELD_NUMBER` | `int` | |
| `HTTP_HEADER_NAME` | `int` | |
| `HTTP_HEADER_VALUE` | `int` | |
| `IGNORED_FIELD_NUMBER` | `int` | |
| `REQUIRED_FIELD_NUMBER` | `int` | |
| `RULES_FIELD_NUMBER` | `int` | |
| `UNKNOWN` | `int` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.internal.validate.validate.validate_pb2_grpc ===
# flyteplugins.union.internal.validate.validate.validate_pb2_grpc
Client and server classes corresponding to protobuf-defined services.
## Directory
### Variables
| Property | Type | Description |
|-|-|-|
| `GRPC_GENERATED_VERSION` | `str` | |
| `GRPC_VERSION` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.remote ===
# flyteplugins.union.remote
Union remote control plane objects.
This module provides remote object classes for Union-specific control plane
entities, following the same pattern as flyte.remote objects.
Example:

```python
from flyteplugins.union.remote import ApiKey

# List all API keys
keys = ApiKey.listall()
for key in keys:
    print(key.client_id)

# Create a new API key
api_key = ApiKey.create(name="my-ci-key")
print(api_key.client_secret)

# Get a specific API key
key = ApiKey.get(client_id="my-client-id")

# Delete an API key
ApiKey.delete(client_id="my-client-id")
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey** | Represents a Union API Key (OAuth Application). |
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment** | Represents role/policy assignments for an identity. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Member** | Represents a Union organization member (user or application). |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy** | Represents a Union RBAC Policy. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role** | Represents a Union RBAC Role. |
| **Integrations > Union > Packages > flyteplugins.union.remote > User** | Represents a Union user. |
## Subpages
- **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey**
- **Integrations > Union > Packages > flyteplugins.union.remote > Assignment**
- **Integrations > Union > Packages > flyteplugins.union.remote > Member**
- **Integrations > Union > Packages > flyteplugins.union.remote > Policy**
- **Integrations > Union > Packages > flyteplugins.union.remote > Role**
- **Integrations > Union > Packages > flyteplugins.union.remote > User**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.remote/apikey ===
# ApiKey
**Package:** `flyteplugins.union.remote`
Represents a Union API Key (OAuth Application).
API Keys in Union are OAuth 2.0 applications that can be used for
headless authentication. They support client credentials flow for
machine-to-machine authentication.
Attributes:

- `pb2`: The underlying protobuf App message.
- `organization`: The organization this API key belongs to (for serverless).
- `encoded_credentials`: Base64-encoded credentials for the `UNION_API_KEY` env var.

Example:

```python
# Create a new API key
api_key = ApiKey.create(name="ci-pipeline")
print(f'export FLYTE_API_KEY="{api_key.encoded_credentials}"')

# List all API keys
for key in ApiKey.listall():
    print(f"{key.client_id}: {key.client_name}")

# Get a specific API key
key = ApiKey.get(client_id="my-client-id")

# Delete an API key
ApiKey.delete(client_id="my-client-id")
```
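`encoded_credentials` should be treated as opaque. Purely for illustration, one common OAuth client-credentials convention is to base64-encode `client_id:client_secret`; whether Union uses exactly this layout is an assumption, and the helper below is hypothetical:

```python
import base64

# Illustration only: base64 over "client_id:client_secret" is a common
# OAuth convention. This is an assumption about the layout, not a
# statement of Union's actual encoding; treat the field as opaque.
def encode_credentials(client_id: str, client_secret: str) -> str:
    raw = f"{client_id}:{client_secret}".encode()
    return base64.b64encode(raw).decode()

token = encode_credentials("my-client-id", "s3cret")
print(token)
```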
## Parameters
```python
class ApiKey(
pb2: App,
organization: str | None,
encoded_credentials: str | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `App` | |
| `organization` | `str \| None` | |
| `encoded_credentials` | `str \| None` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `client_id` | `None` | The OAuth client ID. |
| `client_name` | `None` | The human-readable name of the API key. |
| `client_secret` | `None` | The OAuth client secret (only available on creation). |
## Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey > Methods > create()** | Create a new API key. |
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey > Methods > delete()** | Delete an API key. |
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey > Methods > get()** | Get an API key by client ID. |
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey > Methods > listall()** | List all API keys. |
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey > Methods > to_json()** | Convert the object to a JSON string. |
| **Integrations > Union > Packages > flyteplugins.union.remote > ApiKey > Methods > update()** | Update an API key. |
### create()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await ApiKey.create.aio()`.
```python
def create(
cls,
name: str,
redirect_uris: list[str] | None,
) -> ApiKey
```
Create a new API key.
Example:

```python
api_key = ApiKey.create(name="ci-pipeline")
print(f"Client ID: {api_key.client_id}")
print(f"Client Secret: {api_key.client_secret}")
print(f"Encoded: {api_key.encoded_credentials}")
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | Human-readable name for the API key |
| `redirect_uris` | `list[str] \| None` | OAuth redirect URIs (defaults to localhost callback) |
**Returns**
ApiKey instance with client_secret populated
**Raises**
| Exception | Description |
|-|-|
| `Exception` | If API key creation fails |
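The sync-by-default / `.aio()` dual-call convention noted above can be modeled with a small wrapper. This is a hypothetical sketch of the pattern, not the SDK's actual implementation:

```python
import asyncio

class syncify:
    # Hypothetical sketch of the dual-call convention: calling the method
    # directly blocks, while `.aio(...)` returns an awaitable coroutine.
    def __init__(self, fn):
        self.aio = fn

    def __call__(self, *args, **kwargs):
        return asyncio.run(self.aio(*args, **kwargs))

class Greeter:
    @syncify
    async def greet(name: str) -> str:
        return f"hello {name}"

print(Greeter.greet("flyte"))                   # sync form, blocks
print(asyncio.run(Greeter.greet.aio("flyte")))  # async form, awaitable
```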
### delete()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await ApiKey.delete.aio()`.
```python
def delete(
cls,
client_id: str,
)
```
Delete an API key.
Example:

```python
ApiKey.delete(client_id="old-ci-key")
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `client_id` | `str` | The OAuth client ID to delete |
**Raises**
| Exception | Description |
|-|-|
| `Exception` | If deletion fails |
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await ApiKey.get.aio()`.
```python
def get(
cls,
client_id: str,
) -> ApiKey
```
Get an API key by client ID.
Example:

```python
key = ApiKey.get(client_id="my-client-id")
print(key.client_name)
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `client_id` | `str` | The OAuth client ID |
**Returns**
ApiKey instance
**Raises**
| Exception | Description |
|-|-|
| `Exception` | If API key not found |
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await ApiKey.listall.aio()`.
```python
def listall(
cls,
limit: int,
) -> AsyncIterator[ApiKey]
```
List all API keys.
**Yields:** ApiKey instances.

Example:

```python
for key in ApiKey.listall(limit=10):
    print(f"{key.client_id}: {key.client_name}")
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `limit` | `int` | Maximum number of keys to return |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### update()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await ApiKey.update.aio()`.
```python
def update(
cls,
client_id: str,
client_name: str | None,
redirect_uris: list[str] | None,
) -> ApiKey
```
Update an API key.
Example:

```python
key = ApiKey.update(
    client_id="my-key",
    client_name="renamed-key",
)
```
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `client_id` | `str` | The OAuth client ID to update |
| `client_name` | `str \| None` | New name for the API key |
| `redirect_uris` | `list[str] \| None` | New redirect URIs |
**Returns**
Updated ApiKey instance
**Raises**
| Exception | Description |
|-|-|
| `Exception` | If update fails |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.remote/assignment ===
# Assignment
**Package:** `flyteplugins.union.remote`
Represents role/policy assignments for an identity.
## Parameters
```python
class Assignment(
pb2: IdentityAssignment,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `IdentityAssignment` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `policies` | `None` | |
| `roles` | `None` | |
| `subject` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment > Methods > create()** | Assign a policy to an identity. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment > Methods > get()** | Get assignments for an identity. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment > Methods > listall()** | List assignments for all members in the organization. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment > Methods > to_json()** | Convert the object to a JSON string. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Assignment > Methods > unassign()** | Unassign a policy from an identity. |
### create()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Assignment.create.aio()`.
```python
def create(
cls,
user_subject: str | None,
creds_subject: str | None,
email: str | None,
policy: str,
) -> Assignment
```
Assign a policy to an identity.
Exactly one of user_subject, creds_subject, or email must be provided.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `user_subject` | `str \| None` | User subject identifier. |
| `creds_subject` | `str \| None` | Client credentials application subject. |
| `email` | `str \| None` | User email for lookup. |
| `policy` | `str` | Policy name to assign. |
**Returns:** Assignment for the identity after the policy is assigned.
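The "exactly one of" constraint on `user_subject`, `creds_subject`, and `email` can be checked with a small helper. This is an illustrative sketch; the SDK performs its own validation:

```python
def exactly_one(*values) -> bool:
    # True when exactly one of the candidate identifiers is provided,
    # mirroring the constraint on user_subject / creds_subject / email.
    return sum(v is not None for v in values) == 1

print(exactly_one("user-abc123", None, None))     # True
print(exactly_one(None, None, None))              # False
print(exactly_one("user-abc123", None, "a@b.c"))  # False
```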
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Assignment.get.aio()`.
```python
def get(
cls,
user_subject: str | None,
creds_subject: str | None,
) -> Assignment
```
Get assignments for an identity.
One of user_subject or creds_subject must be provided.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `user_subject` | `str \| None` | |
| `creds_subject` | `str \| None` | |
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Assignment.listall.aio()`.
```python
def listall(
cls,
limit: int,
) -> AsyncIterator[Assignment]
```
List assignments for all members in the organization.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `limit` | `int` | |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### unassign()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Assignment.unassign.aio()`.
```python
def unassign(
cls,
user_subject: str | None,
creds_subject: str | None,
policy: str,
)
```
Unassign a policy from an identity.
One of user_subject or creds_subject must be provided.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `user_subject` | `str \| None` | |
| `creds_subject` | `str \| None` | |
| `policy` | `str` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.remote/member ===
# Member
**Package:** `flyteplugins.union.remote`
Represents a Union organization member (user or application).
## Parameters
```python
class Member(
pb2: EnrichedIdentity,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `EnrichedIdentity` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `identity_type` | `None` | |
| `is_application` | `None` | |
| `is_user` | `None` | |
| `name` | `None` | |
| `subject` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > Member > Methods > listall()** | List all members in the organization. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Member > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Member > Methods > to_json()** | Convert the object to a JSON string. |
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Member.listall.aio()`.
```python
def listall(
cls,
) -> AsyncIterator[Member]
```
List all members in the organization.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.remote/policy ===
# Policy
**Package:** `flyteplugins.union.remote`
Represents a Union RBAC Policy.
## Parameters
```python
class Policy(
pb2: PolicyPb2,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `PolicyPb2` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `bindings` | `None` | |
| `description` | `None` | |
| `name` | `None` | |
| `organization` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy > Methods > create()** | Create a new policy. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy > Methods > delete()** | Delete a policy. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy > Methods > get()** | Get a policy by name. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy > Methods > listall()** | List all policies in the organization. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy > Methods > to_json()** | Convert the object to a JSON string. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Policy > Methods > update()** | Update a policy by diffing bindings and applying add/remove operations. |
### create()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Policy.create.aio()`.
```python
def create(
cls,
name: str,
description: str,
bindings: list[dict] | None,
) -> Policy
```
Create a new policy.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
| `description` | `str` | |
| `bindings` | `list[dict] \| None` | |
### delete()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Policy.delete.aio()`.
```python
def delete(
cls,
name: str,
)
```
Delete a policy.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Policy.get.aio()`.
```python
def get(
cls,
name: str,
) -> Policy
```
Get a policy by name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Policy.listall.aio()`.
```python
def listall(
cls,
limit: int,
) -> AsyncIterator[Policy]
```
List all policies in the organization.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `limit` | `int` | |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### update()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Policy.update.aio()`.
```python
def update(
cls,
name: str,
old_bindings: list[dict],
new_bindings: list[dict],
) -> Policy
```
Update a policy by diffing bindings and applying add/remove operations.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
| `old_bindings` | `list[dict]` | |
| `new_bindings` | `list[dict]` | |
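The diff-and-apply update described above can be sketched as set operations over the binding dicts. This is an illustrative model only; how the SDK actually compares bindings is an assumption here (full-contents equality):

```python
def diff_bindings(old: list, new: list):
    # Compare bindings by their full contents (an assumption); return the
    # bindings to add and to remove to turn `old` into `new`.
    def freeze(b):
        return tuple(sorted(b.items()))
    old_set = {freeze(b) for b in old}
    new_set = {freeze(b) for b in new}
    to_add = [dict(items) for items in new_set - old_set]
    to_remove = [dict(items) for items in old_set - new_set]
    return to_add, to_remove

old = [{"role": "viewer", "scope": "org"}]
new = [{"role": "viewer", "scope": "org"}, {"role": "admin", "scope": "org"}]
to_add, to_remove = diff_bindings(old, new)
print(to_add)     # [{'role': 'admin', 'scope': 'org'}]
print(to_remove)  # []
```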
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.remote/role ===
# Role
**Package:** `flyteplugins.union.remote`
Represents a Union RBAC Role.
## Parameters
```python
class Role(
pb2: RolePb2,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `RolePb2` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `actions` | `None` | |
| `description` | `None` | |
| `name` | `None` | |
| `organization` | `None` | |
| `role_type` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > Role > Methods > create()** | Create a new role. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role > Methods > delete()** | Delete a role. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role > Methods > get()** | Get a role by name. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role > Methods > listall()** | List all roles in the organization. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role > Methods > to_json()** | Convert the object to a JSON string. |
| **Integrations > Union > Packages > flyteplugins.union.remote > Role > Methods > update()** | Update a role. |
### create()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Role.create.aio()`.
```python
def create(
cls,
name: str,
description: str,
actions: list[str] | None,
) -> Role
```
Create a new role.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
| `description` | `str` | |
| `actions` | `list[str] \| None` | |
### delete()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Role.delete.aio()`.
```python
def delete(
cls,
name: str,
)
```
Delete a role.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
### get()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Role.get.aio()`.
```python
def get(
cls,
name: str,
) -> Role
```
Get a role by name.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
### listall()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Role.listall.aio()`.
```python
def listall(
cls,
limit: int,
) -> AsyncIterator[Role]
```
List all roles in the organization.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `limit` | `int` | |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
### update()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await Role.update.aio()`.
```python
def update(
cls,
name: str,
description: str,
actions: list[str] | None,
)
```
Update a role.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `name` | `str` | |
| `description` | `str` | |
| `actions` | `list[str] \| None` | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.remote/user ===
# User
**Package:** `flyteplugins.union.remote`
Represents a Union user.
## Parameters
```python
class User(
pb2: UserPb2,
)
```
| Parameter | Type | Description |
|-|-|-|
| `pb2` | `UserPb2` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `email` | `None` | |
| `first_name` | `None` | |
| `last_name` | `None` | |
| `subject` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.remote > User > Methods > create()** | Create (invite) a new user. |
| **Integrations > Union > Packages > flyteplugins.union.remote > User > Methods > delete()** | Delete a user. |
| **Integrations > Union > Packages > flyteplugins.union.remote > User > Methods > get()** | Get a user by subject identifier. |
| **Integrations > Union > Packages > flyteplugins.union.remote > User > Methods > listall()** | List all users in the organization. |
| **Integrations > Union > Packages > flyteplugins.union.remote > User > Methods > to_dict()** | Convert the object to a JSON-serializable dictionary. |
| **Integrations > Union > Packages > flyteplugins.union.remote > User > Methods > to_json()** | Convert the object to a JSON string. |
### create()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await User.create.aio()`.
```python
def create(
cls,
first_name: str,
last_name: str,
email: str,
) -> User
```
Create (invite) a new user.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `first_name` | `str` | The user's first name. |
| `last_name` | `str` | The user's last name. |
| `email` | `str` | The user's email address. |
**Returns:** User instance for the newly created user.
### delete()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await User.delete.aio()`.
```python
def delete(
cls,
subject: str,
)
```
Delete a user.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `subject` | `str` | |
### get()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await User.get.aio()`.
```python
def get(
cls,
subject: str,
) -> User
```
Get a user by subject identifier.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `subject` | `str` | |
### listall()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await User.listall.aio()`.
```python
def listall(
cls,
limit: int,
email: str | None,
) -> AsyncIterator[User]
```
List all users in the organization.
| Parameter | Type | Description |
|-|-|-|
| `cls` | | |
| `limit` | `int` | Maximum number of users to return. |
| `email` | `str \| None` | Filter by email (server-side, exact match). |
### to_dict()
```python
def to_dict()
```
Convert the object to a JSON-serializable dictionary.
**Returns:** dict: A dictionary representation of the object.
### to_json()
```python
def to_json()
```
Convert the object to a JSON string.
**Returns:** str: A JSON string representation of the object.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.utils.auth ===
# flyteplugins.union.utils.auth
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.utils.auth > AppClientCredentials** | Application client credentials for API key. |
### Methods
| Method | Description |
|-|-|
| **Integrations > Union > Packages > flyteplugins.union.utils.auth > Methods > encode_app_client_credentials()** | Encode app credentials as a base64 string for use as UNION_API_KEY. |
| **Integrations > Union > Packages > flyteplugins.union.utils.auth > Methods > is_serverless_endpoint()** | Check if endpoint is a Union serverless endpoint. |
## Methods
#### encode_app_client_credentials()
```python
def encode_app_client_credentials(
app_credentials: flyteplugins.union.utils.auth.AppClientCredentials,
) -> str
```
Encode app credentials as a base64 string for use as UNION_API_KEY.
| Parameter | Type | Description |
|-|-|-|
| `app_credentials` | `flyteplugins.union.utils.auth.AppClientCredentials` | The application credentials to encode |
**Returns:** Base64-encoded credential string
#### is_serverless_endpoint()
```python
def is_serverless_endpoint(
endpoint: str,
) -> bool
```
Check if endpoint is a Union serverless endpoint.
| Parameter | Type | Description |
|-|-|-|
| `endpoint` | `str` | |
## Subpages
- **Integrations > Union > Packages > flyteplugins.union.utils.auth > AppClientCredentials**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/union/packages/flyteplugins.union.utils.auth/appclientcredentials ===
# AppClientCredentials
**Package:** `flyteplugins.union.utils.auth`
Application client credentials for API key.
## Parameters
```python
class AppClientCredentials(
endpoint: str,
client_id: str,
client_secret: str,
org: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `endpoint` | `str` | |
| `client_id` | `str` | |
| `client_secret` | `str` | |
| `org` | `str` | |
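To show how these fields flow into `encode_app_client_credentials()`, here is a self-contained sketch. The JSON-then-base64 encoding below is an assumption for illustration only; the real function defines the actual wire format, so always use it rather than this sketch when producing a `UNION_API_KEY`:

```python
import base64
import json
from dataclasses import dataclass, asdict

@dataclass
class AppClientCredentials:
    endpoint: str
    client_id: str
    client_secret: str
    org: str

def encode(creds: AppClientCredentials) -> str:
    # Illustrative only: JSON-encode the fields, then base64 the result.
    # The real encode_app_client_credentials() owns the actual format.
    return base64.b64encode(json.dumps(asdict(creds)).encode()).decode()

creds = AppClientCredentials(
    endpoint="demo.my-org.example.com",  # hypothetical endpoint
    client_id="my-app",
    client_secret="s3cret",
    org="my-org",
)
api_key = encode(creds)
# Round-trip to show the encoding is reversible
assert json.loads(base64.b64decode(api_key)) == asdict(creds)
```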
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/vllm ===
# vLLM
## Subpages
- **Integrations > vLLM > Classes**
- **Integrations > vLLM > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/vllm/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment** | App environment backed by vLLM for serving large language models. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/vllm/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > vLLM > Packages > flyteplugins.vllm** | |
## Subpages
- **Integrations > vLLM > Packages > flyteplugins.vllm**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/vllm/packages/flyteplugins.vllm ===
# flyteplugins.vllm
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment** | App environment backed by vLLM for serving large language models. |
### Variables
| Property | Type | Description |
|-|-|-|
| `DEFAULT_VLLM_IMAGE` | `Image` | |
## Subpages
- **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/vllm/packages/flyteplugins.vllm/vllmappenvironment ===
# VLLMAppEnvironment
**Package:** `flyteplugins.vllm`
App environment backed by vLLM for serving large language models.
This environment sets up a vLLM server with the specified model and configuration.
## Parameters
```python
class VLLMAppEnvironment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
args: *args,
command: Optional[Union[List[str], str]],
requires_auth: bool,
scaling: Scaling,
domain: Domain | None,
links: List[Link],
include: List[str],
parameters: List[Parameter],
cluster_pool: str,
timeouts: Timeouts,
image: str | Image | Literal['auto'],
type: str,
port: int | Port,
extra_args: str | list[str],
model_path: str | RunOutput,
model_hf_path: str,
model_id: str,
stream_model: bool,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the application. |
| `depends_on` | `List[Environment]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `description` | `Optional[str]` | |
| `secrets` | `Optional[SecretRequest]` | Secrets that are requested for application. |
| `env_vars` | `Optional[Dict[str, str]]` | Environment variables to set for the application. |
| `resources` | `Optional[Resources]` | |
| `interruptible` | `bool` | |
| `args` | `*args` | |
| `command` | `Optional[Union[List[str], str]]` | |
| `requires_auth` | `bool` | Whether the public URL requires authentication. |
| `scaling` | `Scaling` | Scaling configuration for the app environment. |
| `domain` | `Domain \| None` | Domain to use for the app. |
| `links` | `List[Link]` | |
| `include` | `List[str]` | |
| `parameters` | `List[Parameter]` | |
| `cluster_pool` | `str` | The target cluster_pool where the app should be deployed. |
| `timeouts` | `Timeouts` | |
| `image` | `str \| Image \| Literal['auto']` | |
| `type` | `str` | Type of app. |
| `port` | `int \| Port` | Port the application listens on. Defaults to 8000 for vLLM. |
| `extra_args` | `str \| list[str]` | Extra args to pass to `vllm serve`. See https://docs.vllm.ai/en/stable/configuration/engine_args or run `vllm serve --help` for details. |
| `model_path` | `str \| RunOutput` | Remote path to the model (e.g., an `s3://` URI). |
| `model_hf_path` | `str` | Hugging Face path to model (e.g., Qwen/Qwen3-0.6B). |
| `model_id` | `str` | Model ID exposed by the vLLM server. |
| `stream_model` | `bool` | Set to True to stream model from blob store to the GPU directly. If False, the model will be downloaded to the local file system first and then loaded into the GPU. |
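For orientation, a minimal environment definition might look like the following. The model, resource values, and extra args are illustrative assumptions, not a tested configuration:

```python
import flyte
from flyteplugins.vllm import VLLMAppEnvironment

# Illustrative sketch: serve a small Hugging Face model with vLLM.
vllm_env = VLLMAppEnvironment(
    name="qwen-server",
    model_hf_path="Qwen/Qwen3-0.6B",         # Hugging Face path to the model
    model_id="qwen3-0.6b",                   # model id the vLLM server exposes
    resources=flyte.Resources(gpu="A100:1", memory="16Gi"),  # assumed sizing
    stream_model=True,                       # stream weights from blob store to GPU
    extra_args=["--max-model-len", "8192"],  # passed through to `vllm serve`
)
```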
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
## Methods
| Method | Description |
|-|-|
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > add_dependency()** | Add a dependency to the environment. |
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > clone_with()** | |
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > container_args()** | Return the container arguments for vLLM. |
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > container_cmd()** | |
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > get_port()** | |
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > on_shutdown()** | Decorator to define the shutdown function for the app environment. |
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > on_startup()** | Decorator to define the startup function for the app environment. |
| **Integrations > vLLM > Packages > flyteplugins.vllm > VLLMAppEnvironment > Methods > server()** | Decorator to define the server function for the app environment. |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[list[Environment]],
description: Optional[str],
interruptible: Optional[bool],
kwargs: **kwargs,
) -> VLLMAppEnvironment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[list[Environment]]` | |
| `description` | `Optional[str]` | |
| `interruptible` | `Optional[bool]` | |
| `kwargs` | `**kwargs` | |
### container_args()
```python
def container_args(
serialization_context: SerializationContext,
) -> list[str]
```
Return the container arguments for vLLM.
| Parameter | Type | Description |
|-|-|-|
| `serialization_context` | `SerializationContext` | |
### container_cmd()
```python
def container_cmd(
serialize_context: SerializationContext,
parameter_overrides: list[Parameter] | None,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
| `parameter_overrides` | `list[Parameter] \| None` | |
### get_port()
```python
def get_port()
```
### on_shutdown()
```python
def on_shutdown(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the shutdown function for the app environment.
This function is called after the server function is called.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### on_startup()
```python
def on_startup(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the startup function for the app environment.
This function is called before the server function is called.
The decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
### server()
```python
def server(
fn: Callable[..., None],
) -> Callable[..., None]
```
Decorator to define the server function for the app environment.
This decorated function can be a sync or async function, and accepts input
parameters based on the Parameters defined in the AppEnvironment
definition.
| Parameter | Type | Description |
|-|-|-|
| `fn` | `Callable[..., None]` | |
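The three lifecycle decorators compose in a fixed order: startup before the server function, shutdown after. The bare-bones stand-in class below demonstrates that call order; it is not the real `VLLMAppEnvironment`:

```python
from typing import Callable

class MiniAppEnvironment:
    """Toy model of the on_startup -> server -> on_shutdown lifecycle."""
    def __init__(self) -> None:
        self._startup = self._server = self._shutdown = None

    def on_startup(self, fn: Callable) -> Callable:
        self._startup = fn
        return fn

    def server(self, fn: Callable) -> Callable:
        self._server = fn
        return fn

    def on_shutdown(self, fn: Callable) -> Callable:
        self._shutdown = fn
        return fn

    def run(self) -> list[str]:
        # Startup runs before the server function, shutdown after it.
        events = []
        if self._startup:
            events.append(self._startup())
        events.append(self._server())
        if self._shutdown:
            events.append(self._shutdown())
        return events

app = MiniAppEnvironment()

@app.on_startup
def warm_cache() -> str:
    return "startup"

@app.server
def serve() -> str:
    return "serving"

@app.on_shutdown
def flush() -> str:
    return "shutdown"

print(app.run())  # ['startup', 'serving', 'shutdown']
```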
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/wandb ===
# Weights & Biases
## Subpages
- **Integrations > Weights & Biases > Classes**
- **Integrations > Weights & Biases > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/wandb/classes ===
# Classes
| Class | Description |
|-|-|
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Wandb** | Generates a Weights & Biases run link. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > WandbSweep** | Generates a Weights & Biases Sweep link. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/wandb/packages ===
# Packages
| Package | Description |
|-|-|
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb** | Weights & Biases integration for Flyte tasks and sweeps. |
## Subpages
- **Integrations > Weights & Biases > Packages > flyteplugins.wandb**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/wandb/packages/flyteplugins.wandb ===
# flyteplugins.wandb
## Key features:
- Automatic W&B run initialization with `@wandb_init` decorator
- Automatic W&B links in Flyte UI pointing to runs and sweeps
- Parent/child task support with automatic run reuse
- W&B sweep creation and management with `@wandb_sweep` decorator
- Configuration management with `wandb_config()` and `wandb_sweep_config()`
- Distributed training support (auto-detects PyTorch DDP/torchrun)
## Basic usage:
1. Simple task with W&B logging:
```python
from flyteplugins.wandb import wandb_init, get_wandb_run
@wandb_init(project="my-project", entity="my-team")
@env.task
async def train_model(learning_rate: float) -> str:
wandb_run = get_wandb_run()
wandb_run.log({"loss": 0.5, "learning_rate": learning_rate})
return wandb_run.id
```
2. Parent/Child Tasks with Run Reuse:
```python
@wandb_init # Automatically reuses parent's run ID
@env.task
async def child_task(x: int) -> str:
wandb_run = get_wandb_run()
wandb_run.log({"child_metric": x * 2})
return wandb_run.id
@wandb_init(project="my-project", entity="my-team")
@env.task
async def parent_task() -> str:
wandb_run = get_wandb_run()
wandb_run.log({"parent_metric": 100})
# Child reuses parent's run by default (run_mode="auto")
await child_task(5)
return wandb_run.id
```
3. Configuration with context manager:
```python
from flyteplugins.wandb import wandb_config
r = flyte.with_runcontext(
custom_context=wandb_config(
project="my-project",
entity="my-team",
tags=["experiment-1"]
)
).run(train_model, learning_rate=0.001)
```
4. Creating new runs for child tasks:
```python
@wandb_init(run_mode="new") # Always creates a new run
@env.task
async def independent_child() -> str:
wandb_run = get_wandb_run()
wandb_run.log({"independent_metric": 42})
return wandb_run.id
```
5. Running sweep agents in parallel:
```python
import asyncio
from flyteplugins.wandb import wandb_sweep, get_wandb_sweep_id, get_wandb_context
@wandb_init
async def objective():
wandb_run = wandb.run
config = wandb_run.config
...
wandb_run.log({"loss": loss_value})
@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
wandb.agent(sweep_id, function=objective, count=count, project=get_wandb_context().project)
return agent_id
@wandb_sweep
@env.task
async def run_parallel_sweep(num_agents: int = 2, trials_per_agent: int = 5) -> str:
sweep_id = get_wandb_sweep_id()
# Launch agents in parallel
agent_tasks = [
sweep_agent(agent_id=i + 1, sweep_id=sweep_id, count=trials_per_agent)
for i in range(num_agents)
]
# Wait for all agents to complete
await asyncio.gather(*agent_tasks)
return sweep_id
# Run with 2 parallel agents
r = flyte.with_runcontext(
custom_context={
**wandb_config(project="my-project", entity="my-team"),
**wandb_sweep_config(
method="random",
metric={"name": "loss", "goal": "minimize"},
parameters={
"learning_rate": {"min": 0.0001, "max": 0.1},
"batch_size": {"values": [16, 32, 64]},
}
)
}
).run(run_parallel_sweep, num_agents=2, trials_per_agent=5)
```
6. Distributed Training Support:
The plugin auto-detects distributed training from environment variables
(RANK, WORLD_SIZE, LOCAL_RANK, etc.) set by torchrun/torch.distributed.elastic.
The `rank_scope` parameter controls the scope of run creation:
- `"global"` (default): Global scope - 1 run/group across all workers
- `"worker"`: Worker scope - 1 run/group per worker
By default (`run_mode="auto"`, `rank_scope="global"`):
- Single-node: Only rank 0 logs (1 run)
- Multi-node: Only global rank 0 logs (1 run)
```python
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, get_wandb_run
torch_env = flyte.TaskEnvironment(
name="torch_env",
resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "5Gi"), gpu="V100:4"),
plugin_config=Elastic(nnodes=2, nproc_per_node=2),
)
@wandb_init
@torch_env.task
async def train_distributed():
torch.distributed.init_process_group("nccl")
# Only global rank 0 gets a W&B run, other ranks get None
run = get_wandb_run()
if run:
run.log({"loss": loss})
return run.id if run else "non-primary-rank"
```
Use `rank_scope="worker"` to get 1 run per worker:
```python
@wandb_init(rank_scope="worker")
@torch_env.task
async def train_distributed_per_worker():
# Multi-node: local rank 0 of each worker gets a W&B run (1 run per worker)
run = get_wandb_run()
if run:
run.log({"loss": loss})
return run.id if run else "non-primary-rank"
```
Use `run_mode="shared"` for all ranks to log to shared run(s):
```python
@wandb_init(run_mode="shared") # rank_scope="global": 1 shared run across all ranks
@torch_env.task
async def train_distributed_shared():
# All ranks log to the same W&B run (with x_label to identify each rank)
run = get_wandb_run()
run.log({"rank_metric": value})
return run.id
@wandb_init(run_mode="shared", rank_scope="worker") # 1 shared run per worker
@torch_env.task
async def train_distributed_shared_per_worker():
run = get_wandb_run()
run.log({"rank_metric": value})
return run.id
```
Use `run_mode="new"` for each rank to have its own W&B run:
```python
@wandb_init(run_mode="new") # rank_scope="global": all runs in 1 group
@torch_env.task
async def train_distributed_separate_runs():
# Each rank gets its own W&B run (grouped in W&B UI)
# Run IDs: {base}-rank-{global_rank}
run = get_wandb_run()
run.log({"rank_metric": value})
return run.id
@wandb_init(run_mode="new", rank_scope="worker") # runs grouped per worker
@torch_env.task
async def train_distributed_separate_runs_per_worker():
run = get_wandb_run()
run.log({"rank_metric": value})
return run.id
```
Decorator order: `@wandb_init` or `@wandb_sweep` must be the outermost decorator:
```python
@wandb_init
@env.task
async def my_task():
...
```
## Directory
### Classes
| Class | Description |
|-|-|
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Wandb** | Generates a Weights & Biases run link. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > WandbSweep** | Generates a Weights & Biases Sweep link. |
### Methods
| Method | Description |
|-|-|
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > download_wandb_run_dir()** | Download wandb run data from wandb cloud. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > download_wandb_run_logs()** | Traced function to download wandb run logs after task completion. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > download_wandb_sweep_dirs()** | Download all run data for a wandb sweep. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > download_wandb_sweep_logs()** | Traced function to download wandb sweep logs after task completion. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > get_distributed_info()** | Get distributed training info if running in a distributed context. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > get_wandb_context()** | Get wandb config from current Flyte context. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > get_wandb_run()** | Get the current wandb run if within a `@wandb_init` decorated task or trace. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > get_wandb_run_dir()** | Get the local directory path for the current wandb run. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > get_wandb_sweep_context()** | Get wandb sweep config from current Flyte context. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > get_wandb_sweep_id()** | Get the current wandb `sweep_id` if within a `@wandb_sweep` decorated task. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > wandb_config()** | Create wandb configuration. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > wandb_init()** | Decorator to automatically initialize wandb for Flyte tasks and wandb sweep objectives. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > wandb_sweep()** | Decorator to create a wandb sweep and make `sweep_id` available. |
| **Integrations > Weights & Biases > Packages > flyteplugins.wandb > Methods > wandb_sweep_config()** | Create wandb sweep configuration for hyperparameter optimization. |
## Methods
#### download_wandb_run_dir()
```python
def download_wandb_run_dir(
run_id: typing.Optional[str],
path: typing.Optional[str],
include_history: bool,
) -> str
```
Download wandb run data from wandb cloud.
Downloads all run files and optionally exports metrics history to JSON.
This enables access to wandb data from any task or after workflow completion.
Downloaded contents:
- summary.json - final summary metrics (always exported)
- metrics_history.json - step-by-step metrics (if include_history=True)
- Plus any files synced by wandb (requirements.txt, wandb_metadata.json, etc.)
| Parameter | Type | Description |
|-|-|-|
| `run_id` | `typing.Optional[str]` | The wandb run ID to download. If `None`, uses the current run's ID from context (useful for shared runs across tasks). |
| `path` | `typing.Optional[str]` | Local directory to download files to. If `None`, downloads to `/tmp/wandb_runs/{run_id}`. |
| `include_history` | `bool` | If `True`, exports the step-by-step metrics history to `metrics_history.json`. Defaults to `True`. |
**Returns**
Local path where files were downloaded.
**Raises**
| Exception | Description |
|-|-|
| `RuntimeError` | If no `run_id` provided and no active run in context. |
| `wandb.errors.CommError` | If run not found in wandb cloud. |
> [!NOTE]
> There may be a brief delay between when files are written locally and
> when they're available in wandb cloud. For immediate local access
> within the same task, use `get_wandb_run_dir()` instead.
#### download_wandb_run_logs()
```python
def download_wandb_run_logs(
    run_id: str,
) -> Dir
```
Traced function to download wandb run logs after task completion.
This function is called automatically when `download_logs=True` is set
in `@wandb_init` or `wandb_config()`. The downloaded files appear as a
trace output in the Flyte UI.
| Parameter | Type | Description |
|-|-|-|
| `run_id` | `str` | The wandb run ID to download. |
**Returns**
Dir containing the downloaded wandb run files.
**Raises**
| Exception | Description |
|-|-|
| `RuntimeError` | If download fails (network error, run not found, auth failure, etc.) |
#### download_wandb_sweep_dirs()
```python
def download_wandb_sweep_dirs(
    sweep_id: typing.Optional[str],
    base_path: typing.Optional[str],
    include_history: bool,
) -> list[str]
```
Download all run data for a wandb sweep.
Queries the wandb API for all runs in the sweep and downloads their files
and metrics history. This is useful for collecting results from all sweep
trials after completion.
| Parameter | Type | Description |
|-|-|-|
| `sweep_id` | `typing.Optional[str]` | The wandb sweep ID. If `None`, uses the current sweep's ID from context (set by `@wandb_sweep` decorator). |
| `base_path` | `typing.Optional[str]` | Base directory to download files to. Each run's files will be in a subdirectory named by run_id. If `None`, uses `/tmp/wandb_runs/`. |
| `include_history` | `bool` | If `True`, exports the step-by-step metrics history to metrics_history.json for each run. Defaults to `True`. |
**Returns**
List of local paths where run data was downloaded.
**Raises**
| Exception | Description |
|-|-|
| `RuntimeError` | If no sweep_id provided and no active sweep in context. |
| `wandb.errors.CommError` | If sweep not found in wandb cloud. |
#### download_wandb_sweep_logs()
```python
def download_wandb_sweep_logs(
    sweep_id: str,
) -> Dir
```
Traced function to download wandb sweep logs after task completion.
This function is called automatically when `download_logs=True` is set
in `@wandb_sweep` or `wandb_sweep_config()`. The downloaded files appear as a
trace output in the Flyte UI.
| Parameter | Type | Description |
|-|-|-|
| `sweep_id` | `str` | The wandb sweep ID to download. |
**Returns**
Dir containing the downloaded wandb sweep run files.
**Raises**
| Exception | Description |
|-|-|
| `RuntimeError` | If download fails (network error, sweep not found, auth failure, etc.) |
#### get_distributed_info()
```python
def get_distributed_info() -> dict | None
```
Get distributed training info if running in a distributed context.
This function auto-detects distributed training from environment variables
set by torchrun/torch.distributed.elastic.
**Returns**
dict | None: Dictionary with distributed info or None if not distributed.
- rank: Global rank (0 to world_size-1)
- local_rank: Rank within the node (0 to local_world_size-1)
- world_size: Total number of processes
- local_world_size: Processes per node
- worker_index: Node/worker index (0 to num_workers-1)
- num_workers: Total number of nodes/workers
#### get_wandb_context()
```python
def get_wandb_context()
```
Get wandb config from current Flyte context.
#### get_wandb_run()
```python
def get_wandb_run()
```
Get the current wandb run if within a `@wandb_init` decorated task or trace.
The run is initialized when the `@wandb_init` context manager is entered.
Returns None if not within a `wandb_init` context.
**Returns:** `wandb.sdk.wandb_run.Run` | `None`: The current wandb run object or None.
#### get_wandb_run_dir()
```python
def get_wandb_run_dir()
```
Get the local directory path for the current wandb run.
Use this for accessing files written by the current task without any
network calls. For accessing files from other tasks (or after a task
completes), use `download_wandb_run_dir()` instead.
**Returns**
Local path to wandb run directory (`wandb.run.dir`) or `None` if no
active run.
#### get_wandb_sweep_context()
```python
def get_wandb_sweep_context()
```
Get wandb sweep config from current Flyte context.
#### get_wandb_sweep_id()
```python
def get_wandb_sweep_id()
```
Get the current wandb `sweep_id` if within a `@wandb_sweep` decorated task.
Returns `None` if not within a `wandb_sweep` context.
**Returns:** `str` | `None`: The sweep ID or None.
#### wandb_config()
```python
def wandb_config(
    project: typing.Optional[str],
    entity: typing.Optional[str],
    id: typing.Optional[str],
    name: typing.Optional[str],
    tags: typing.Optional[list[str]],
    config: typing.Optional[dict[str, typing.Any]],
    mode: typing.Optional[str],
    group: typing.Optional[str],
    run_mode: typing.Literal['auto', 'new', 'shared'],
    rank_scope: typing.Literal['global', 'worker'],
    download_logs: bool,
    kwargs: **kwargs,
)
```
Create wandb configuration.
This function works in two contexts:
1. With `flyte.with_runcontext()` - sets global wandb config
2. As a context manager - overrides config for specific tasks
| Parameter | Type | Description |
|-|-|-|
| `project` | `typing.Optional[str]` | W&B project name |
| `entity` | `typing.Optional[str]` | W&B entity (team or username) |
| `id` | `typing.Optional[str]` | Unique run id (auto-generated if not provided) |
| `name` | `typing.Optional[str]` | Human-readable run name |
| `tags` | `typing.Optional[list[str]]` | List of tags for organizing runs |
| `config` | `typing.Optional[dict[str, typing.Any]]` | Dictionary of hyperparameters |
| `mode` | `typing.Optional[str]` | "online", "offline" or "disabled" |
| `group` | `typing.Optional[str]` | Group name for related runs |
| `run_mode` | `typing.Literal['auto', 'new', 'shared']` | "auto", "new" or "shared". Controls whether tasks create new W&B runs or share existing ones. - "auto" (default): Creates new run if no parent run exists, otherwise shares parent's run - "new": Always creates a new wandb run with a unique ID - "shared": Always shares the parent's run ID In distributed training context (single-node): - "auto" (default): Only rank 0 logs. - "shared": All ranks log to a single shared W&B run. - "new": Each rank gets its own W&B run (grouped in W&B UI). Multi-node: behavior depends on `rank_scope`. |
| `rank_scope` | `typing.Literal['global', 'worker']` | "global" or "worker". Controls which ranks log in distributed training. run_mode="auto": - "global" (default): Only global rank 0 logs (1 run total). - "worker": Local rank 0 of each worker logs (1 run per worker). run_mode="shared": - "global": All ranks log to a single shared W&B run. - "worker": Ranks per worker log to a single shared W&B run (1 run per worker). run_mode="new": - "global": Each rank gets its own W&B run (1 run total). - "worker": Each rank gets its own W&B run grouped per worker -> N runs. |
| `download_logs` | `bool` | If `True`, downloads wandb run files after task completes and shows them as a trace output in the Flyte UI |
| `kwargs` | `**kwargs` | |
#### wandb_init()
```python
def wandb_init(
    _func: typing.Optional[~F],
    run_mode: typing.Optional[typing.Literal['auto', 'new', 'shared']],
    rank_scope: typing.Optional[typing.Literal['global', 'worker']],
    download_logs: typing.Optional[bool],
    project: typing.Optional[str],
    entity: typing.Optional[str],
    kwargs: **kwargs,
)
```
Decorator to automatically initialize wandb for Flyte tasks and wandb sweep objectives.
Decorator order: for tasks, `@wandb_init` must be the outermost decorator:
```python
@wandb_init
@env.task
async def my_task():
    ...
```
This decorator:
1. Initializes wandb when the context manager is entered
2. Auto-generates unique run ID from Flyte action context if not provided
3. Makes the run available via get_wandb_run()
4. Automatically adds a W&B link to the task in the Flyte UI
5. Automatically finishes the run after completion
6. Optionally downloads run logs as a trace output (if download_logs=True)
| Parameter | Type | Description |
|-|-|-|
| `_func` | `typing.Optional[~F]` | |
| `run_mode` | `typing.Optional[typing.Literal['auto', 'new', 'shared']]` | Controls whether to create a new W&B run or share an existing one: - "auto" (default): Creates new run if no parent run exists, otherwise shares parent's run - "new": Always creates a new wandb run with a unique ID - "shared": Always shares the parent's run ID (useful for child tasks) In distributed training context (single-node): - "auto" (default): Only rank 0 logs. - "shared": All ranks log to a single shared W&B run. - "new": Each rank gets its own W&B run (grouped in W&B UI). Multi-node: behavior depends on `rank_scope`. |
| `rank_scope` | `typing.Optional[typing.Literal['global', 'worker']]` | Flyte-specific rank scope - "global" or "worker". Controls which ranks log in distributed training. run_mode="auto": - "global" (default): Only global rank 0 logs (1 run total). - "worker": Local rank 0 of each worker logs (1 run per worker). run_mode="shared": - "global": All ranks log to a single shared W&B run. - "worker": Ranks per worker log to a single shared W&B run (1 run per worker). run_mode="new": - "global": Each rank gets its own W&B run (1 run total). - "worker": Each rank gets its own W&B run grouped per worker -> N runs. |
| `download_logs` | `typing.Optional[bool]` | If `True`, downloads wandb run files after task completes and shows them as a trace output in the Flyte UI. If None, uses the value from `wandb_config()` context if set. |
| `project` | `typing.Optional[str]` | W&B project name (overrides context config if provided) |
| `entity` | `typing.Optional[str]` | W&B entity/team name (overrides context config if provided) |
| `kwargs` | `**kwargs` | |
#### wandb_sweep()
```python
def wandb_sweep(
    _func: typing.Optional[~F],
    project: typing.Optional[str],
    entity: typing.Optional[str],
    download_logs: typing.Optional[bool],
    kwargs: **kwargs,
)
```
Decorator to create a wandb sweep and make `sweep_id` available.
This decorator:
1. Creates a wandb sweep using config from context
2. Makes `sweep_id` available via `get_wandb_sweep_id()`
3. Automatically adds a W&B sweep link to the task
4. Optionally downloads all sweep run logs as a trace output (if `download_logs=True`)
Decorator Order:
For tasks, `@wandb_sweep` must be the outermost decorator:
```python
@wandb_sweep
@env.task
async def my_task():
    ...
```
| Parameter | Type | Description |
|-|-|-|
| `_func` | `typing.Optional[~F]` | |
| `project` | `typing.Optional[str]` | W&B project name (overrides context config if provided) |
| `entity` | `typing.Optional[str]` | W&B entity/team name (overrides context config if provided) |
| `download_logs` | `typing.Optional[bool]` | If `True`, downloads all sweep run files after the task completes and shows them as a trace output in the Flyte UI. If `None`, uses the value from the `wandb_sweep_config()` context if set. |
| `kwargs` | `**kwargs` | |
#### wandb_sweep_config()
CODE23
Create wandb sweep configuration for hyperparameter optimization.
See: https://docs.wandb.ai/models/sweeps/sweep-config-keys
| Parameter | Type | Description |
|-|-|-|
| `method` | `typing.Optional[str]` | Sweep method (e.g., "random", "grid", "bayes") |
| `metric` | `typing.Optional[dict[str, typing.Any]]` | Metric to optimize (e.g., {"name": "loss", "goal": "minimize"}) |
| `parameters` | `typing.Optional[dict[str, typing.Any]]` | Parameter definitions for the sweep |
| `project` | `typing.Optional[str]` | W&B project for the sweep |
| `entity` | `typing.Optional[str]` | W&B entity for the sweep |
| `prior_runs` | `typing.Optional[list[str]]` | List of prior run IDs to include in the sweep analysis |
| `name` | `typing.Optional[str]` | Sweep name (auto-generated as `{run_name}-{action_name}` if not provided) |
| `download_logs` | `bool` | If `True`, downloads all sweep run files after task completes and shows them as a trace output in the Flyte UI |
| `kwargs` | `**kwargs` | |
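For example, here is a minimal sketch of the arguments one might pass, shown as a plain dict (the parameter names come from the table above; the `project` and `entity` values are made-up placeholders):

```python
# Hypothetical sketch: the keyword arguments described in the table above,
# in the shape that wandb_sweep_config() expects.
# The project/entity values are placeholders, not real names.
sweep_kwargs = {
    "method": "bayes",                                   # "random", "grid", or "bayes"
    "metric": {"name": "val_loss", "goal": "minimize"},  # metric to optimize
    "parameters": {                                      # search space definitions
        "learning_rate": {"min": 1e-4, "max": 1e-1},
        "batch_size": {"values": [16, 32, 64]},
    },
    "project": "my-wandb-project",
    "entity": "my-team",
    "download_logs": True,
}
```

See the linked W&B sweep configuration reference for the full set of accepted `method`, `metric`, and `parameters` keys.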
## Subpages
- **Wandb**
- **WandbSweep**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/wandb/packages/flyteplugins.wandb/wandb ===
# Wandb
**Package:** `flyteplugins.wandb`
Generates a Weights & Biases run link.
## Parameters
```python
class Wandb(
host: str,
project: typing.Optional[str],
entity: typing.Optional[str],
run_mode: typing.Literal['auto', 'new', 'shared'],
rank_scope: typing.Literal['global', 'worker'],
id: typing.Optional[str],
name: str,
_is_distributed: bool,
_worker_index: typing.Optional[int],
)
```
| Parameter | Type | Description |
|-|-|-|
| `host` | `str` | Base W&B host URL |
| `project` | `typing.Optional[str]` | W&B project name (overrides context config if provided) |
| `entity` | `typing.Optional[str]` | W&B entity/team name (overrides context config if provided) |
| `run_mode` | `typing.Literal['auto', 'new', 'shared']` | Determines the link behavior:<br>- "auto" (default): uses the parent's run if available, otherwise creates a new one<br>- "new": always creates a new W&B run with a unique ID<br>- "shared": always shares the parent's run ID<br><br>In a single-node distributed training context:<br>- "auto" (default): only rank 0 logs.<br>- "shared": all ranks log to a single shared W&B run.<br>- "new": each rank gets its own W&B run (grouped in the W&B UI).<br><br>Multi-node: behavior depends on `rank_scope`. |
| `rank_scope` | `typing.Literal['global', 'worker']` | Flyte-specific rank scope ("global" or "worker"); controls which ranks log in distributed training.<br><br>run_mode="auto":<br>- "global" (default): only global rank 0 logs (1 run total).<br>- "worker": local rank 0 of each worker logs (1 run per worker).<br><br>run_mode="shared":<br>- "global": all ranks log to a single shared W&B run.<br>- "worker": the ranks on each worker log to a shared W&B run (1 run per worker).<br><br>run_mode="new":<br>- "global": each rank gets its own W&B run.<br>- "worker": each rank gets its own W&B run, grouped per worker. |
| `id` | `typing.Optional[str]` | Optional W&B run ID (overrides context config if provided) |
| `name` | `str` | Link name in the Flyte UI |
| `_is_distributed` | `bool` | |
| `_worker_index` | `typing.Optional[int]` | |
## Methods
| Method | Description |
|-|-|
| **get_link()** | Returns a task log link given the action. |
### get_link()
```python
def get_link(
run_name: str,
project: str,
domain: str,
context: typing.Dict[str, str],
parent_action_name: str,
action_name: str,
pod_name: str,
kwargs,
) -> str
```
Returns a task log link given the action.
Link can have template variables that are replaced by the backend.
| Parameter | Type | Description |
|-|-|-|
| `run_name` | `str` | The name of the run. |
| `project` | `str` | The project name. |
| `domain` | `str` | The domain name. |
| `context` | `typing.Dict[str, str]` | Additional context for generating the link. |
| `parent_action_name` | `str` | The name of the parent action. |
| `action_name` | `str` | The name of the action. |
| `pod_name` | `str` | The name of the pod. |
| `kwargs` | `**kwargs` | Additional keyword arguments. |
**Returns:** The generated link.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/integrations/wandb/packages/flyteplugins.wandb/wandbsweep ===
# WandbSweep
**Package:** `flyteplugins.wandb`
Generates a Weights & Biases Sweep link.
## Parameters
```python
class WandbSweep(
host: str,
project: typing.Optional[str],
entity: typing.Optional[str],
id: typing.Optional[str],
name: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `host` | `str` | Base W&B host URL |
| `project` | `typing.Optional[str]` | W&B project name (overrides context config if provided) |
| `entity` | `typing.Optional[str]` | W&B entity/team name (overrides context config if provided) |
| `id` | `typing.Optional[str]` | Optional W&B sweep ID (overrides context config if provided) |
| `name` | `str` | Link name in the Flyte UI |
## Methods
| Method | Description |
|-|-|
| **get_link()** | Returns a task log link given the action. |
### get_link()
```python
def get_link(
run_name: str,
project: str,
domain: str,
context: typing.Dict[str, str],
parent_action_name: str,
action_name: str,
pod_name: str,
kwargs,
) -> str
```
Returns a task log link given the action.
Link can have template variables that are replaced by the backend.
| Parameter | Type | Description |
|-|-|-|
| `run_name` | `str` | The name of the run. |
| `project` | `str` | The project name. |
| `domain` | `str` | The domain name. |
| `context` | `typing.Dict[str, str]` | Additional context for generating the link. |
| `parent_action_name` | `str` | The name of the parent action. |
| `action_name` | `str` | The name of the action. |
| `pod_name` | `str` | The name of the pod. |
| `kwargs` | `**kwargs` | Additional keyword arguments. |
**Returns:** The generated link.
=== PAGE: https://www.union.ai/docs/v2/flyte/community ===
# Community
Flyte is an open source project that is built and maintained by a community of contributors.
Union AI is the primary maintainer of Flyte and the developer of Union.ai, a closed-source commercial product built on top of Flyte.
Since the success of Flyte is essential to the success of Union.ai,
the company is dedicated to building and expanding the Flyte open source project and community.
For information on how to get involved and how to keep in touch, see **Joining the community**.
## Contributing to the codebase
The full Flyte codebase is open source and available on GitHub.
If you are interested, see **Contributing code**.
## Contributing to documentation
Union AI maintains and hosts both Flyte and Union documentation at [www.union.ai/docs](/docs/v2/root/).
The two sets of documentation are deeply integrated, as the Union product is built on top of Flyte.
To better maintain both sets of docs, they are hosted in the same repository (but rendered so that you can choose to view one or the other).
Both the Flyte and Union documentation are open source.
Flyte community members and Union customers are both welcome to contribute to the documentation.
If you are interested, see [Contributing documentation and examples](./contributing-docs/_index).
## Subpages
- **Joining the community**
- **Contributing code**
- **Contributing docs and examples**
=== PAGE: https://www.union.ai/docs/v2/flyte/community/joining-the-community ===
# Joining the community
Keeping the lines of communication open is important in growing and maintaining the Flyte community.
Please join us on:
- [Slack](https://slack.flyte.org)
- [GitHub Discussions](https://github.com/flyteorg/flyte/discussions)
- [Twitter](https://twitter.com/flyteorg)
- [LinkedIn](https://www.linkedin.com/groups/13962256)
## Community sync
1. **When**: First Tuesday of every month, 9:00 AM Pacific Time.
2. **Where**: Live streamed on [YouTube](https://www.youtube.com/@flyteorg/streams) and [LinkedIn](https://www.linkedin.com/company/union-ai/events/).
3. **Watch the recordings**: [here](https://www.youtube.com/live/d81Jd4rfmzw?feature=shared).
4. **Import the public calendar**: [here](https://lists.lfaidata.foundation/g/flyte-announce/ics/12031983/2145304139/feed.ics) to not miss any event.
5. **Want to present?** Fill out [this form](https://tally.so/r/wgN8LM). We're eager to learn from you!
You're welcome to join and learn from other community members sharing their experiences with Flyte or any other technology from the AI ecosystem.
## Contributor's sync
1. **When**: Every 2 weeks on Thursdays. Alternating schedule between 11:00 AM PT and 7:00 AM PT.
2. **Where**: Live on [Zoom](https://zoom-lfx.platform.linuxfoundation.org/meeting/92309721545?password=c93d76a7-801a-47c6-9916-08e38e5a5c1f).
3. **Purpose**: Address questions from new contributors, discuss active initiatives, and RFCs.
4. **Import the public calendar**: [here](https://lists.lfaidata.foundation/g/flyte-announce/ics/12031983/2145304139/feed.ics) to not miss any event.
## Newsletter
[Join the Flyte mailing list](https://lists.lfaidata.foundation/g/flyte-announce/join) to receive the monthly newsletter.
## Slack guidelines
Flyte strives to build and maintain an open, inclusive, productive, and self-governing open source community.
In consequence, we expect all community members to respect the following guidelines:
### Abide by the [LF's Code of Conduct](https://lfprojects.org/policies/code-of-conduct/)
As a Linux Foundation project, we must enforce the rules that govern professional and positive open source communities.
### Avoid using DMs and @mentions
Whenever possible, post your questions and responses in public channels so other community members can benefit from the conversation and outcomes.
Exceptions to this are when you need to share private or sensitive information.
In such a case, the outcome should still be shared publicly.
Limit the use of `@mentions` of other community members to be considerate of notification noise.
### Make use of threads
Threads help us keep conversations contained and organized, reducing the time it takes to give you the support you need.
**Thread best practices:**
- Don't break your question into multiple messages. Put everything in one.
- For long questions, write a few sentences in the first message, and put the rest in a thread.
- If there's a code snippet (more than 5 lines of code), put it inside the thread.
- Avoid using the "Also send to channel" feature unless it's really necessary.
- If you have multiple questions, break them into separate messages so each can be answered in its own thread.
### Do not post the same question across multiple channels
If you consider that a question needs to be shared on other channels, ask it once and then indicate explicitly that you're cross-posting.
If you're having a tough time getting the support you need (or aren't sure where to go!), please DM `@David Espejo` or `@Samhita Alla` for support.
### Do not solicit members of our Slack
The Flyte community exists to collaborate with, learn from, and support one another.
It is not a space to pitch your products or services directly to our members via public channels, private channels, or direct messages.
We are excited to have a growing presence from vendors to help answer questions from community members as they may arise, but we have a strict 3-strike policy against solicitation:
- **First occurrence**: We'll give you a friendly but public reminder that the behavior is inappropriate according to our guidelines.
- **Second occurrence**: We'll send you a DM warning that any additional violations will result in removal from the community.
- **Third occurrence**: We'll delete or ban your account.
We reserve the right to ban users without notice if they are clearly spamming our community members.
If you want to promote a product or service, go to the `#shameless-promotion` channel and make sure to follow these rules:
- Don't post more than two promotional posts per week.
- Non-relevant topics aren't allowed.
Messages that don't follow these rules will be deleted.
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-code ===
# Contributing code
Thank you for your interest in Flyte!
> [!NOTE]
> This page is part of the Flyte 2 documentation.
> If you are interested in contributing code for Flyte 1, switch the selector at the top of the page to **v1**.
## Flyte 2
The Flyte 2 SDK source code is available on [GitHub](https://github.com/flyteorg/flyte-sdk) under the same Apache license as the original Flyte 1.
You are welcome to take a look, [download the package](https://pypi.org/project/flyte/#history) and try running code locally.
The Flyte 2 backend is not yet available as open source (but it will be soon!).
To run Flyte 2 code now, you can apply for a [beta preview of the Union 2 backend](https://www.union.ai/beta).
When the Flyte 2 backend is released we will roll out a full contributor program just as we have for Flyte 1.
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs ===
# Contributing docs and examples
> [!NOTE]
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.
We welcome contributions to the docs and examples for both Flyte and Union.
This section will explain how the docs site works, how to author and build it locally, and how to publish your changes.
## The combined Flyte and Union docs site
As the primary maintainer and contributor of the open-source Flyte project, Union AI is responsible for hosting the Flyte documentation.
Additionally, Union AI is also the company behind the commercial Union.ai product, which is based on Flyte.
Since Flyte and Union.ai share a lot of common functionality, much of the documentation content is common between the two.
However, there are some significant differences, not only between Flyte and Union.ai but also among the different Union.ai product offerings (Serverless, BYOC, and Self-managed).
To effectively and efficiently maintain the documentation for all of these variants, we employ a single-source-of-truth approach where:
* All content is stored in a single GitHub repository, [`unionai/unionai-docs`](https://github.com/unionai/unionai-docs)
* All content is published on a single website, [`www.union.ai/docs`](/docs/v2/root/).
* The website has a variant selector at the top of the page that lets you choose which variant you want to view:
* Flyte OSS
* Union Serverless
* Union BYOC
* Union Self-managed
* There is also a version selector. Currently, two versions are available:
* v1 (the original docs for Flyte/Union 1.x)
* v2 (the new docs for Flyte/Union 2.0, which is the one you are currently viewing)
## Versions
The two versions of the docs are stored in separate branches of the GitHub repository:
* [`v1` branch](https://github.com/unionai/unionai-docs/tree/v1) for the v1 docs.
* [`main` branch](https://github.com/unionai/unionai-docs) for the v2 docs.
See **Contributing docs and examples > Versions** for more details.
## Common build infrastructure
The build infrastructure for the docs site (Hugo configuration, layouts, themes, build scripts, and Python tools) is maintained in a separate repository, [`unionai/unionai-docs-infra`](https://github.com/unionai/unionai-docs-infra), which is imported as a [git submodule](https://git-scm.com/book/en/v2/Git-Tools-Submodules) at `unionai-docs-infra/` in the `unionai-docs` repository.
This means both the `main` (v2) and `v1` content branches share the same build infrastructure.
Changes to the build system are made once in `unionai-docs-infra` and are picked up by both branches, keeping them in sync without duplicating build logic.
## Variants
Within each branch the multiple variants are supported by using conditional rendering:
* Each page of content has a `variants` front matter field that specifies which variants the page is applicable to.
* Within each page, rendering logic can be used to include or exclude content based on the selected variant.
The result is that:
* Content that is common to all variants is authored and stored once.
There is no need to keep multiple copies of the same content in-sync.
* Content specific to a variant is conditionally rendered based on the selected variant.
See **Contributing docs and examples > Variants** for more details.
## Both Flyte and Union docs are open source
Since the docs are now combined in one repository, and the Flyte docs are open source, the Union docs are also open source.
All the docs are available for anyone to contribute to: Flyte contributors, Union customers, and Union employees.
If you are a Flyte contributor, you will be contributing docs related to Flyte features and functionality, but in many cases these features and functionality will also be available in Union.
Because the docs site is a single source for all the documentation, when you make changes related to Flyte that are also valid for Union you do so in the same place.
This is by design and is a key feature of the docs site.
## Subpages
- **Quick start**
- **Variants**
- **Versions**
- **Authoring**
- **Shortcodes**
- **Redirects**
- **API docs**
- **LLM-optimized documentation**
- **Publishing**
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/quick-start ===
# Quick start
## Prerequisites
The docs site is built using the [Hugo](https://gohugo.io/) static site generator.
You will need to install it to build the site locally.
See [Hugo Installation](https://gohugo.io/getting-started/installing/).
## Clone the repository
Clone the [`unionai/unionai-docs`](https://github.com/unionai/unionai-docs) repository to your local machine.
The content is located in the `content/` folder in the form of Markdown files.
The hierarchy of the files and folders under `content/` directly reflect the URL and navigation structure of the site.
## Live preview
Next, set up the live preview: go to the root of your local repository checkout and copy the sample configuration file to `hugo.local.toml`:
```bash
cp unionai-docs-infra/hugo.local.toml~sample hugo.local.toml
```
This file contains the configuration for the live preview:
By default, it is set to display the `flyte` variant of the docs site and enables the flags `show_inactive`, `highlight_active`, and `highlight_keys` (more about these below).
Now you can start the live preview server by running:
```bash
make dev
```
This will build the site and launch a local server at `http://localhost:1313`.
Go to that URL to see the live preview. Leave the server running.
As you edit the content you will see the changes reflected in the live preview.
## Distribution build
To build the site for distribution, run:
```bash
make dist
```
This will build the site locally just as it is built by the Cloudflare CI for production.
You can view the result of the build by running a local server:
```bash
make serve
```
This will start a local server at `http://localhost:9000` and serve the contents of the `dist/` folder. You can also specify a port number:
```bash
make serve PORT=<port>
```
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/variants ===
# Variants
The docs site supports the ability to show or hide content based on the current variant selection.
There are separate mechanisms for:
* Including or excluding entire pages based on the selected variant.
* Conditional rendering of content within a page based on the selected variant using an if-then-like construct.
* Rendering keywords as variables that change based on the selected variant.
Currently, the docs site supports three variants:
- **Flyte OSS**: The open-source Flyte project.
- **BYOC**: The Union.ai product that is hosted on the customer's infrastructure but managed by Union AI.
- **Self-managed**: The Union.ai product that is hosted and managed by the customer.
Each variant is referenced in the page logic using its respective code name: `flyte`, `byoc`, or `selfmanaged`.
The available set of variants is defined in the `config.<variant>.toml` files in the `unionai-docs-infra/` directory.
## Variants at the whole-page level
The docs site supports the ability to show or hide entire pages based on the selected variant.
Not all pages are available in all variants because features differ across the variants.
In the public website, if you are on a page in one variant and you change to a different variant, the page will change to the same page in the new variant *if it exists*.
If it does not exist, you will see a message indicating that the page is not available in the selected variant.
In the source Markdown, the presence or absence of a page in a given variant is governed by the `variants` field in the front matter of the page.
For example, if you look at the Markdown source for [this page (the page you are currently viewing)](https://github.com/unionai/unionai-docs/blob/main/content/community/contributing-docs.md), you will see the following front matter:
```markdown
---
title: Platform overview
weight: 1
variants: +flyte +byoc +selfmanaged
---
```
The `variants` field has the value:
`+flyte +byoc +selfmanaged`
The `+` indicates that the page is available for the specified variant.
In this case, the page is available for all three variants.
If you wanted to make the page available for only the `flyte` variant, you would change the `variants` field to:
`+flyte -byoc -selfmanaged`
In [live preview mode](./authoring#live-preview) with the `show_inactive` flag enabled, you will see all pages in the navigation tree, with the ones unavailable for the current variant grayed out.
As you can see, the `variants` field expects a space-separated list of keywords:
* The code names for the current variants are `flyte`, `byoc`, and `selfmanaged`.
* All supported variants must be included explicitly in every `variants` field with a leading `+` or `-`. There is no default behavior.
* The supported variants are configured in the `unionai-docs-infra/` directory in the files named `config.<variant>.toml`.
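For illustration only (this helper is hypothetical and not part of the docs build tooling), the space-separated `variants` field could be parsed like this:

```python
def parse_variants(field: str) -> dict:
    """Parse a front matter 'variants' value such as '+flyte -byoc -selfmanaged'.

    Returns a mapping from variant code name to whether the page is included.
    Illustrative sketch only; not the site's actual implementation.
    """
    result = {}
    for token in field.split():
        if not token or token[0] not in "+-":
            raise ValueError(f"variant token must start with + or -: {token!r}")
        # '+name' includes the page in that variant, '-name' excludes it
        result[token[1:]] = token.startswith("+")
    return result
```

For example, `parse_variants("+flyte -byoc -selfmanaged")` marks the page as included only in the `flyte` variant.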
## Conditional rendering within a page
Content can also differ *within a page* based on the selected variant.
This is done with conditional rendering using the `{{< variant >}}` and `{{< key >}}` [Hugo shortcodes](https://gohugo.io/content-management/shortcodes/).
### {{< variant >}}
The syntax for the `{{< variant >}}` shortcode is:
```markdown
{{< variant <variant-list> >}}
...
{{< /variant >}}
```
Where `<variant-list>` is a space-separated list of the code names of the variants you want to show the content for.
Note that the variant construct can only directly contain other shortcode constructs, not plain Markdown.
In the most common case, you will want to use the `{{< markdown >}}` shortcode (which can contain Markdown) inside the `{{< variant >}}` shortcode to render Markdown content, like this:
```markdown
{{< variant byoc selfmanaged >}}
{{< markdown >}}
This content is only visible in the `byoc` and `selfmanaged` variants.
{{< /markdown >}}
{{< button-link text="Contact Us" target="https://union.ai/contact" >}}
{{< /variant >}}
```
For more details on the `{{< variant >}}` shortcode, see [Shortcodes > `variant`](./shortcodes#variant).
### {{< key >}}
The syntax for the `{{< key >}}` shortcode is:
CODE2
Where `<key-name>` is the name of the key you want to render.
For example, if you want to render the product name keyword, you would use:
CODE3
The available key names are defined in the `[params.key]` section of the `hugo.site.toml` configuration file in the root of the repository.
For example, the `product_name` key used above is defined in that file as
CODE4
This means that in any content that appears in the `flyte` variant of the site, the `{{< key product_name >}}` shortcode will be replaced with `Flyte`, and in any content that appears in the `byoc` or `selfmanaged` variants, it will be replaced with `Union.ai`.
For more details on the `{{< key >}}` shortcode, see [Shortcodes > `key`](./shortcodes#key).
## Full example
Here is a full example. If you look at the Markdown source for [this page (the page you are currently viewing)](https://github.com/unionai/unionai-docs/blob/main/content/community/contributing-docs/variants.md), you will see the following section:
```markdown
> **This text is visible in all variants.**
>
> {{< variant flyte >}}
> {{< markdown >}}
>
> **This text is only visible in the `flyte` variant.**
>
> {{< /markdown >}}
> {{< /variant >}}
> {{< variant byoc selfmanaged >}}
> {{< markdown >}}
>
> **This text is only visible in the `byoc` and `selfmanaged` variants.**
>
> {{< /markdown >}}
> {{< /variant >}}
>
> **Below is a `{{< key product_full_name >}}` shortcode.
> It will be replaced with the current variant's full name:**
>
> **{{< key product_full_name >}}**
```
This Markdown source is rendered as:
> **This text is visible in all variants.**
>
> **This text is only visible in the `flyte` variant.**
>
> **Below is a `{{< key product_full_name >}}` shortcode.
> It will be replaced with the current variant's full name:**
>
> **Flyte OSS**
If you switch between variants with the variant selector at the top of the page, you will see the content change accordingly.
## Adding a new variant
A variant is a term we use to identify a product or major section of the site.
Each variant has a dedicated token that identifies it, and all resources are
tagged to be either included or excluded when the variant is built.
> Adding new variants is a rare event and should be reserved for new products
> or major developments.
>
> If you think adding a new variant is the way to go, please double-check with
> the infra admin to confirm before doing all the work below.
### Location
When deployed, each variant occupies a folder under the site root:
`https://<host>/<variant>/`
For example, if we have a variant `acme`, the built content goes to:
`https://<host>/acme/`
### Creating a new variant
To create a new variant a few steps are required:
| File | Changes |
| ----------------------------------------- | -------------------------------------------------------------- |
| `hugo.site.toml` | Add to `params.variant_weights` and all `params.key` |
| `unionai-docs-infra/hugo.toml` | Add to `params.search` |
| `unionai-docs-infra/Makefile` | Add a new `make variant` to `dist` target |
| `*.md` content files | Add either `+` or `-` for the new variant to all content pages |
| `unionai-docs-infra/config.<variant>.toml` | Create a new file and configure `baseURL` and `params.variant` |
### Testing the new variant
As you develop the new variant, it is recommended to maintain a semi-stable `pre-release/` branch to confirm that everything works and the content looks good. This also allows others to collaborate by creating PRs against it (`base=pre-release/` instead of `main`) without trampling on each other, and enables parallel reviews.
Once the variant branch is correct, merge it into `main`.
### Building (just) the variant
You can build the production version of the variant,
which also triggers all the safety checks,
by invoking the variant build:
```bash
make variant VARIANT=<variant>
```
For example:
```bash
make variant VARIANT=byoc
```
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/versions ===
# Versions
In addition to the product variants, the docs site also supports multiple versions of the documentation.
The version selector is located at the top of the page, next to the variant selector.
Versions and variants are independent of each other, with the version being "above" the variant in the URL hierarchy.
The URL for version `v2` of the current page (the one you are on right now) in the Flyte variant is:
`/docs/v2/flyte/community/contributing-docs/versions`
while the URL for version `v1` of the same page is:
`/docs/v1/flyte/community/contributing-docs/versions`
### Versions are branches
The versioning system is based on long-lived Git branches in the `unionai/unionai-docs` GitHub repository:
- The `main` branch contains the latest version of the documentation. Currently, `v2`.
- Other versions of the docs are contained in branches named `vX`, where `X` is the major version number. Currently, there is one other version, `v1`.
## Archive versions
An "archive version" is a static snapshot of the site at a given point in time,
freezing its content and structure for historical reference.
### How to create an archive version
1. Create a new branch from `main` named `vX`, e.g. `v3`.
2. Add the version to the `VERSION` field in the `makefile.inc` file, e.g. `VERSION := v3`.
3. Add the version to the `versions` field in the `hugo.ver.toml` file, e.g. `versions = [ "v1", "v2", "v3" ]`.
> [!NOTE]
> **Important:** You must update the `versions` field in **ALL** published and archived versions of the site.
### Publishing an archive version
> [!NOTE]
> This step can only be done by a Union employee.
1. Update the `docs_archive_versions` variable in the `docs_archive_locals.tf` Terraform file.
2. Create a PR for the changes.
3. Once the PR is merged, run the production pipeline to activate the new version.
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/authoring ===
# Authoring
## Getting started
Content is located in the `content` folder.
To create a new page, simply create a new Markdown file in the appropriate folder and start writing it!
## Target the right branch
Remember that there are two production branches in the docs: `main` and `v1`.
* **For Flyte or Union 1, create a branch off of `v1` and target your pull request to `v1`**
* **For Flyte or Union 2, create a branch off of `main` and target your pull request to `main`**
## Live preview
While editing, you can use Hugo's local live preview capabilities.
Simply execute
```bash
make dev
```
This will build the site and launch a local server at `http://localhost:1313`.
Go to that URL to see the live preview. Leave the server running.
As you edit the preview will update automatically.
See [Publishing](./publishing) for how to set up your machine.
## Pull Requests + Site Preview
Pull requests will create a preview build of the site on CloudFlare.
Check the pull request for a dynamic link to the site changes within that PR.
## Page Visibility
This site uses variants, which means different "flavors" of the content.
For a given page, its variant visibility is governed by the `variants:` field in the front matter of the page source.
For each variant, you specify `+` to include it or `-` to exclude it.
For example:
```markdown
---
title: My Page
variants: -flyte +byoc -selfmanaged
---
```
In this example the page will be:
* Included in Serverless and BYOC.
* Excluded from Flyte and Self-managed.
> [!NOTE]
> All variants must be explicitly listed in the `variants` field.
> This helps avoid missing or extraneous pages.
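As an illustrative sketch (not part of the actual site tooling), the `variants:` value can be read as a list of signed tokens:

```python
def parse_variants(field: str) -> dict:
    """Parse a 'variants:' value like '-flyte +byoc -selfmanaged' into a visibility map."""
    visibility = {}
    for token in field.split():
        sign, name = token[0], token[1:]
        visibility[name] = (sign == "+")  # '+' includes the page, '-' excludes it
    return visibility

print(parse_variants("-flyte +byoc -selfmanaged"))
# {'flyte': False, 'byoc': True, 'selfmanaged': False}
```

A page is then rendered only in the variants mapped to `True`.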
## Page order
Pages are ordered by the value of the `weight` field (an integer >= 0) in the frontmatter of the page:
1. The higher the weight, the lower the page sits in navigation ordering among its peers in the same folder.
2. Pages with no weight field (or `weight = 0`) will be ordered last.
3. Pages of the same weight will be sorted alphabetically by their title.
4. Folders are ordered among their peers (other folders and pages at the same level of the hierarchy) by the weight of their `_index.md` page.
For example:
```markdown
---
title: My Page
weight: 3
---
```
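The ordering rules above can be condensed into a sort key. This is only an illustrative sketch; the real ordering is computed by Hugo:

```python
def nav_sort_key(page: dict):
    weight = page.get("weight", 0)
    # Missing or zero weight sorts last; otherwise ascending weight, ties broken by title
    return (weight == 0, weight, page["title"])

pages = [
    {"title": "Beta", "weight": 3},
    {"title": "Alpha"},                  # no weight: ordered last
    {"title": "Aardvark", "weight": 3},  # same weight as Beta: alphabetical
    {"title": "Intro", "weight": 1},
]
print([p["title"] for p in sorted(pages, key=nav_sort_key)])
# ['Intro', 'Aardvark', 'Beta', 'Alpha']
```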
## Page settings
| Setting | Type | Description |
| ------------------ | ---- | --------------------------------------------------------------------------------- |
| `top_menu` | bool | If `true` the item becomes a tab at the top and its hierarchy goes to the sidebar |
| `sidebar_expanded` | bool | If `true` the section is permanently expanded in the sidebar. |
| `site_root` | bool | If `true` indicates that the page is the site landing page |
| `toc_max` | int | Maximum heading level to include in the right navigation table of contents. |
| `llm_readable_bundle` | bool | If `true`, generates a `section.md` bundle for this section. Requires `{{* llm-bundle-note */>}}` shortcode. See [LLM-optimized documentation](./llm-docs). |
## Conditional Content
The site has "flavors" of the documentation. We leverage the `{{* variant */>}}` tag to control
which content is rendered on which flavor.
Refer to [**Variants**](./shortcodes#variants) for detailed explanation.
## Warnings and Notices
You can write regular Markdown and use the notation below to create information and warning boxes:
```markdown
> [!NOTE] This is the note title
> You write the note content here. It can be
> anything you want.
```
Or if you want a warning:
```markdown
> [!WARNING] This is the title of the warning
> And here you write what you want to warn about.
```
## Special Content Generation
There are various short codes to generate content or special components (tabs, dropdowns, etc.)
Refer to [**Content Generation**](./shortcodes) for more information.
## Python Generated Content
You can generate pages from markdown-commented Python files.
At the top of your `.md` file, add:
```markdown
---
layout: py_example
example_file: /path/to/your/file.py
run_command: union run --remote tutorials//path/to/your/file.py main
source_location: https://www.github.com/unionai/unionai-examples/tree/main/tutorials/path/to/your/file.py
---
```
Where the referenced file looks like this:
```python
# # Credit Default Prediction with XGBoost & NVIDIA RAPIDS
#
# In this tutorial, we will use NVIDIA RAPIDS `cudf` DataFrame library for preprocessing
# data and XGBoost, an optimized gradient boosting library, for credit default prediction.
# We'll learn how to declare NVIDIA `A100` for our training function and `ImageSpec`
# for specifying our python dependencies.
# {{run-on-union}}
# ## Declaring workflow dependencies
#
# First, we start by importing all the dependencies that are required by this workflow:
import os
import gc
from pathlib import Path
from typing import Tuple
import fsspec
from flytekit import task, workflow, current_context, Resources, ImageSpec, Deck
from flytekit.types.file import FlyteFile
from flytekit.extras.accelerators import A100
```
Note that the text content is embedded in comments as Markdown, and the code is normal python code.
The generator will convert the markdown into normal page text content and the code into code blocks within that Markdown content.
### Run on Union Instructions
We can add the Run on Union instructions anywhere in the content by annotating the desired location with `{{run-on-union}}`, as shown in the Python example above.
The resulting **Run on Union** section in the rendered docs will include the run command and source location,
specified as `run_command` and `source_location` in the front matter of the corresponding `.md` page.
## Jupyter Notebooks
You can also generate pages from Jupyter notebooks.
At the top of your `.md` file, add:
```markdown
---
jupyter_notebook: /path/to/your/notebook.ipynb
---
```
Jupyter notebook conversion is handled automatically as part of the production build (`make dist`).
The conversion tool is located at `unionai-docs-infra/tools/jupyter_generator`.
**Committing the change:** When the PR is pushed, a CI check verifies consistency between the notebook and its generated content. Please ensure that if you change the notebook, you run `make dist` to update the generated page.
## Mapped Keys (`{{* key */>}}`)
`key` is a special shortcode that lets us define values mapped per variant. For example,
the product name differs depending on whether the variant is Flyte, Union BYOC, etc.
Rather than wrapping each occurrence in an `if variant` construct, we define a single key,
such as `product_full_name`, whose rendered value automatically reflects the current variant.
Please refer to [{{* key */>}} shortcode](./shortcodes#key) for more details.
## Mermaid Graphs
To embed Mermaid diagrams in a page, place the diagram code inside a fenced code block with the `mermaid` language identifier.
Also add `mermaid: true` to the top of your page to enable rendering.
> [!NOTE]
> You can use [Mermaid's playground](https://www.mermaidchart.com/play) to design diagrams and get the code
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/shortcodes ===
# Shortcodes
This site has special blocks (shortcodes) that can be used to generate content components for the site.
> [!NOTE]
> You can see examples by running the dev server and visiting
> [http://localhost:1313/__docs_builder__/shortcodes/](http://localhost:1313/__docs_builder__/shortcodes/).
> Note that this page is only visible locally. It does not appear in the menus or in the production build.
>
> If you need instructions on how to create the local environment and get the
> `localhost:1313` server running, please refer to the [local development guide](./publishing).
## How to specify a "shortcode"
A shortcode is a tag that generates the HTML that is displayed.
Where applicable, a shortcode can take parameters and/or wrap content.
> [!NOTE]
> If you specify content, you must include a closing tag.
Examples:
* A shortcode that just outputs something
```markdown
{{* key product_name */>}}
```
* A shortcode that has content inside
```markdown
{{* markdown */>}}
* Your markdown
* goes here
{{* /markdown */>}}
```
* A shortcode with parameters
```markdown
{{* link-card target="union-sdk" icon="workflow" title="Union SDK" */>}}
The Union SDK provides the Python API for building Union workflows and apps.
{{* /link-card */>}}
```
> [!NOTE]
> If you're wondering why we have a `{{* markdown */>}}` when we can generate markdown at the top level, it is due to a quirk in Hugo:
> * At the top level of the page, Hugo can render markdown directly, interspersed with shortcodes.
> * However, *inside* a container shortcode, Hugo can only render *either* other shortcodes *or* Markdown.
> * The `{{* markdown */>}}` shortcode is designed to contain only Markdown (not other shortcodes).
> * All other container shortcodes are designed to contain only other shortcodes.
## Variants
The big difference between this site and other documentation sites is that we generate multiple "flavors" of the documentation that are slightly different from each other. We call these "variants."
When you are writing your content, and you want a specific part of the content to be conditional to a flavor, say "BYOC", you surround that with `variant`.
>[!NOTE]
> `variant` is a container, so inside you will specify what you are wrapping.
> You can wrap any of the shortcodes listed in this document.
Example:
```markdown
{{* variant byoc selfmanaged */>}}
{{* markdown */>}}
**The quick brown fox signed up for Union!**
{{* /markdown */>}}
{{* button-link text="Contact Us" target="https://union.ai/contact" */>}}
{{* /variant */>}}
```
## Component Library
### `{{* audio */>}}`
Generates an audio media player.
### `{{* grid */>}}`
Creates a fixed column grid for lining up content.
### `{{* variant */>}}`
Filters content based on which flavor you're seeing.
### `{{* link-card */>}}`
A floating, clickable, navigable card.
### `{{* markdown */>}}`
Generates a markdown block, to be used inside containers such as `{{* dropdown */>}}` or `{{* variant */>}}`.
### `{{* multiline */>}}`
Generates a multiple line, single paragraph. Useful for making a multiline table cell.
### `{{* tabs */>}}` and `{{* tab */>}}`
Generates a tab panel with content switching per tab.
### `{{* key */>}}`
Outputs one of the pre-defined keywords.
Enables inline text that differs per-variant without using the heavy-weight `{{* variant */>}}...{{* /variant */>}}` construct.
Take, for example, the following:
```markdown
The {{* key product_name */>}} platform is awesome.
```
In the Flyte variant of the site this will render as:
> The Flyte platform is awesome.
While, in the BYOC, Self-managed and Serverless variants of the site it will render as:
> The Union.ai platform is awesome.
You can add keywords and specify their value, per variant, in `hugo.site.toml`:
```toml
[params.key.product_full_name]
flyte = "Flyte"
byoc = "Union BYOC"
selfmanaged = "Union Self-managed"
```
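Conceptually, key lookup is a nested mapping from key name to variant to value. A minimal sketch of that idea (the actual resolution is done by Hugo from `hugo.site.toml`):

```python
key_map = {
    "product_full_name": {
        "flyte": "Flyte",
        "byoc": "Union BYOC",
        "selfmanaged": "Union Self-managed",
    },
}

def render_key(key: str, variant: str) -> str:
    # Resolve a key reference for the active variant
    return key_map[key][variant]

print(render_key("product_full_name", "byoc"))  # Union BYOC
```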
#### List of available keys
| Key | Description | Example Usage (Flyte → Union) |
| ----------------- | ------------------------------------- | ---------------------------------------------------------------------- |
| default_project | Default project name used in examples | `{{* key default_project */>}}` → "flytesnacks" or "default" |
| product_full_name | Full product name | `{{* key product_full_name */>}}` → "Flyte OSS" or "Union.ai BYOC" |
| product_name | Short product name | `{{* key product_name */>}}` → "Flyte" or "Union.ai" |
| product | Lowercase product identifier | `{{* key product */>}}` → "flyte" or "union" |
| kit_name | SDK name | `{{* key kit_name */>}}` → "Flytekit" or "Union" |
| kit | Lowercase SDK identifier | `{{* key kit */>}}` → "flytekit" or "union" |
| kit_as | SDK import alias | `{{* key kit_as */>}}` → "fl" or "union" |
| kit_import | SDK import statement | `{{* key kit_import */>}}` → "flytekit as fl" or "union" |
| kit_remote | Remote client class name | `{{* key kit_remote */>}}` → "FlyteRemote" or "UnionRemote" |
| cli_name | CLI tool name | `{{* key cli_name */>}}` → "Pyflyte" or "Union" |
| cli | Lowercase CLI tool identifier | `{{* key cli */>}}` → "pyflyte" or "union" |
| ctl_name | Control tool name | `{{* key ctl_name */>}}` → "Flytectl" or "Uctl" |
| ctl | Lowercase control tool identifier | `{{* key ctl */>}}` → "flytectl" or "uctl" |
| config_env | Configuration environment variable | `{{* key config_env */>}}` → "FLYTECTL_CONFIG" or "UNION_CONFIG" |
| env_prefix | Environment variable prefix | `{{* key env_prefix */>}}` → "FLYTE" or "UNION" |
| docs_home | Documentation home URL | `{{* key docs_home */>}}` → "/docs/flyte" or "/docs/byoc" |
| map_func | Map function name | `{{* key map_func */>}}` → "map_task" or "map" |
| logo | Logo image filename | `{{* key logo */>}}` → "flyte-logo.svg" or "union-logo.svg" |
| favicon | Favicon image filename | `{{* key favicon */>}}` → "flyte-favicon.ico" or "union-favicon.ico" |
### `{{* download */>}}`
Generates a download link.
Parameters:
- `url`: The URL to download from
- `filename`: The filename to save the file as
- `text`: The text to display for the download link
Example:
```markdown
{{* download "/_static/public/public-key.txt" "public-key.txt" */>}}
```
### `{{* docs_home */>}}`
Produces a link to the home page of the documentation for a specific variant.
Example:
```markdown
[See this in Flyte]({{* docs_home flyte */>}}/wherever/you/want/to/go/in/flyte/docs)
```
### `{{* py_class_docsum */>}}`, `{{* py_class_ref */>}}`, and `{{* py_func_ref */>}}`
Helper functions to track Python classes in Flyte documentation, so we can link them to
the appropriate documentation.
Parameters:
- name of the class
- text to add to the link
Example:
```markdown
Please see {{* py_class_ref flyte.core.Image */>}} for more details.
```
### `{{* icon name */>}}`
Uses a named icon in the content.
Example:
```markdown
[Download {{* icon download */>}}](/download)
```
### `{{* code */>}}`
Includes a code snippet or file.
Parameters:
- `file`: The path to the file to include.
- `fragment`: The name of the fragment to include.
- `from`: The line number to start including from.
- `to`: The line number to stop including at.
- `lang`: The language of the code snippet.
- `show_fragments`: Whether to show the fragment names in the code block.
- `highlight`: Whether to highlight the code snippet.
The examples in this section use this file as a base:
```
def main():
    """
    A sample function
    """
    return 42

# {{docs-fragment entrypoint}}
if __name__ == "__main__":
    main()
# {{/docs-fragment entrypoint}}
```
*Source: /_static/__docs_builder__/sample.py*
Link to [/_static/__docs_builder__/sample.py](/_static/__docs_builder__/sample.py)
#### Including a section of a file: `{{docs-fragment}}`
```markdown
{{* code file="/_static/__docs_builder__/sample.py" fragment=entrypoint lang=python */>}}
```
Effect:
```
if __name__ == "__main__":
    main()
```
*Source: /_static/__docs_builder__/sample.py*
#### Including a file with a specific line range: `from` and `to`
```markdown
{{* code file="/_static/__docs_builder__/sample.py" from=2 to=4 lang=python */>}}
```
Effect:
```
    """
    A sample function
    """
```
*Source: /_static/__docs_builder__/sample.py*
#### Including a whole file
Simply specify no filters, just the `file` attribute:
```markdown
{{* code file="/_static/__docs_builder__/sample.py" */>}}
```
> [!NOTE]
> Note that without `show_fragments=true` the fragment markers will not be shown.
Effect:
```
def main():
    """
    A sample function
    """
    return 42

if __name__ == "__main__":
    main()
```
*Source: /_static/__docs_builder__/sample.py*
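Because the `{{docs-fragment}}` markers are plain comments, fragment extraction amounts to pulling the lines between a matched open/close pair. A rough sketch of that idea (not the actual implementation):

```python
import re

def extract_fragment(text: str, name: str):
    # Capture everything between the open and close docs-fragment markers
    pattern = re.compile(
        r"#\s*\{\{docs-fragment %s\}\}\n(.*?)#\s*\{\{/docs-fragment %s\}\}" % (name, name),
        re.S,
    )
    match = pattern.search(text)
    return match.group(1) if match else None

sample = (
    "def main():\n"
    "    return 42\n"
    "\n"
    "# {{docs-fragment entrypoint}}\n"
    'if __name__ == "__main__":\n'
    "    main()\n"
    "# {{/docs-fragment entrypoint}}\n"
)
print(extract_fragment(sample, "entrypoint"))
```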
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/redirects ===
# Redirects
We use Cloudflare's Bulk Redirect to map URLs that moved to their new location,
so the user does not get a 404 using the old link.
The redirect files are in CSV format, with the following structure:
`<source>,<target>,302,TRUE,FALSE,TRUE,TRUE`
- `<source>`: the URL without `https://`
- `<target>`: the full URL (including `https://`) to send the user to
Redirects are recorded in the `unionai-docs-infra/redirects.csv` file.
To take effect, this file must be applied to the production environment on CloudFlare by a Union employee.
If you need to add a new redirect, please create a pull request with the change to `redirects.csv` and a note indicating that you would like to have it applied to production.
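Since each redirect is a CSV row with fixed trailing flags, constructing one can be sketched as follows (illustrative only; entries are normally edited by hand):

```python
import csv
import io

def redirect_row(source: str, target: str) -> list:
    # Column layout documented above: source, target, 302, TRUE, FALSE, TRUE, TRUE
    return [source, target, "302", "TRUE", "FALSE", "TRUE", "TRUE"]

buffer = io.StringIO()
csv.writer(buffer).writerow(redirect_row(
    "docs.union.ai/administration",
    "/docs/v1/byoc//user-guide/administration",
))
print(buffer.getvalue().strip())
# docs.union.ai/administration,/docs/v1/byoc//user-guide/administration,302,TRUE,FALSE,TRUE,TRUE
```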
## `docs.union.ai` redirects
For redirects from the old `docs.union.ai` site to the new `www.union.ai/docs` site, we use the original request URL. For example:
| | |
|-|-|
| Request URL | `https://docs.union.ai/administration` |
| Target URL | `/docs/v1/byoc//user-guide/administration` |
| Redirect Entry | `docs.union.ai/administration,/docs/v1/byoc//user-guide/administration,302,TRUE,FALSE,TRUE,TRUE` |
## `docs.flyte.org` redirects
For redirects from the old `docs.flyte.org` to the new `www.union.ai/docs`, we replace the `docs.flyte.org` in the request URL with the special prefix `www.union.ai/_r_/flyte`. For example:
| | |
|-|-|
| Request URL | `https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.dynamic.html` |
| Converted request URL | `www.union.ai/_r_/flyte/projects/flytekit/en/latest/generated/flytekit.dynamic.html` |
| Target URL | `/docs/v1/flyte//api-reference/flytekit-sdk/packages/flytekit.core.dynamic_workflow_task/` |
| Redirect Entry | `www.union.ai/_r_/flyte/projects/flytekit/en/latest/generated/flytekit.dynamic.html,/docs/v1/flyte//api-reference/flytekit-sdk/packages/flytekit.core.dynamic_workflow_task/,302,TRUE,FALSE,TRUE,TRUE` |
The special prefix is used so that we can include both `docs.union.ai` and `docs.flyte.org` redirects in the same file and apply them on the same domain (`www.union.ai`).
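The prefix substitution described above is a straight host replacement, which can be sketched as:

```python
def convert_flyte_request(url: str) -> str:
    # Swap the docs.flyte.org host for the special routing prefix
    prefix = "https://docs.flyte.org/"
    if not url.startswith(prefix):
        raise ValueError("not a docs.flyte.org URL")
    return "www.union.ai/_r_/flyte/" + url[len(prefix):]

print(convert_flyte_request(
    "https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.dynamic.html"
))
# www.union.ai/_r_/flyte/projects/flytekit/en/latest/generated/flytekit.dynamic.html
```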
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/api-docs ===
# API docs
You can import Python APIs and host them on the site. To do that you will use
the `unionai-docs-infra/tools/api_generator` to parse and create the appropriate markdown.
Please refer to [`api_generator/README`](https://github.com/unionai/unionai-docs-infra/blob/main/tools/api_generator/README.md) for more details.
## API naming convention
All the buildable APIs are defined in Makefiles of the form:
`unionai-docs-infra/Makefile.api.<name>`
To build one, run `make -f unionai-docs-infra/Makefile.api.<name>` and observe the setup
requirements in the `README.md` file above. Alternatively, `make update-api-docs` will regenerate all API docs.
## Package Resource Resolution
When scanning the packages we need to know when to include or exclude an object
(class, function, variable) from the documentation. The parser will follow this
workflow to decide, in order, if the resource must be in or out:
1. If the `__all__: List[str]` package-level variable is present: only the resources
listed will be exposed. All other resources are excluded.
Example:
```python
from http import HTTPStatus, HTTPMethod
__all__ = ["HTTPStatus", "LocalThingy"]
class LocalThingy:
...
class AnotherLocalThingy:
...
```
In this example only `HTTPStatus` and `LocalThingy` will show in the docs.
Both `HTTPMethod` and `AnotherLocalThingy` are ignored.
2. If `__all__` is not present, these rules are observed:
- All imported packages are ignored
- All objects starting with `_` are ignored
Example:
```python
from http import HTTPStatus, HTTPMethod
class _LocalThingy:
...
class AnotherLocalThingy:
...
def _a_func():
...
def b_func():
...
```
In this example only `AnotherLocalThingy` and `b_func` will show in the docs.
Neither the imports nor `_LocalThingy` will show in the documentation.
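The decision workflow above can be condensed into a small predicate (a simplification of the real parser):

```python
def should_document(name: str, is_imported: bool, all_list) -> bool:
    # Rule 1: an explicit __all__ wins outright
    if all_list is not None:
        return name in all_list
    # Rule 2: otherwise skip imports and underscore-prefixed names
    return not is_imported and not name.startswith("_")

# The two examples above, replayed:
assert should_document("HTTPStatus", True, ["HTTPStatus", "LocalThingy"])
assert not should_document("HTTPMethod", True, ["HTTPStatus", "LocalThingy"])
assert should_document("b_func", False, None)
assert not should_document("_LocalThingy", False, None)
```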
## Tips and Tricks
1. If a package has no `__all__` list and no resources without a leading `_`
(i.e., everything is an import or starts with `_`), the package will have no content and thus will not be generated.
2. If you want to export something you `from ___ import ____`, you _must_
use `__all__` to add the private import to the public list.
3. If all your methods follow the Python convention of everything private starts
with `_` and everything you want public does not, you do not need to have a
`__all__` allow list.
## Enabling auto-linking for plugins
When you generate API documentation for a plugin, the build process creates two data files that enable automatic linking from documentation to API references:
- `data/{name}.yaml` - Hugo data file for server-side code block linking
- `static/{name}-linkmap.json` - JSON file for client-side inline code linking
For plugins, use the `--short-names` flag when generating API docs (already enabled in `Makefile.api.plugins`). This generates both fully qualified names (`flyteplugins.wandb.wandb_init`) and short names (`wandb_init`) in the linkmap, allowing docs to reference APIs without the full package path.
To enable auto-linking for a new plugin, you need to register these files in two places:
### 1. Server-side code block linking
Edit `unionai-docs-infra/layouts/partials/autolink-python.html` and add your plugin's data file to the merge chain:
```go-html-template
{{- /* Load and merge all API data sources */ -}}
{{- $flyteapi := dict "identifiers" (dict) "methods" (dict) "packages" (dict) -}}
{{- with site.Data.flytesdk -}}
{{- $flyteapi = merge $flyteapi (dict "identifiers" (.identifiers | default dict) "methods" (.methods | default dict) "packages" (.packages | default dict)) -}}
{{- end -}}
{{- with site.Data.wandb -}}
{{- $flyteapi = merge $flyteapi (dict "identifiers" (merge $flyteapi.identifiers (.identifiers | default dict)) "methods" (merge $flyteapi.methods (.methods | default dict)) "packages" (merge $flyteapi.packages (.packages | default dict))) -}}
{{- end -}}
{{- /* Add your plugin here following the same pattern */ -}}
```
### 2. Client-side inline code linking
Edit `static/js/inline-code-linker.js` and add your plugin's linkmap file to the `linkmapFiles` array:
```javascript
const linkmapFiles = ['flytesdk-linkmap.json', 'wandb-linkmap.json'];
// Add your plugin's linkmap file here, e.g., 'myplugin-linkmap.json'
```
### How auto-linking works
Once configured, the following will be automatically linked:
- **Code blocks**: Python code in fenced code blocks will have API references linked. For example, `wandb_init()` in a Python code block will link to its API documentation.
- **Inline code**: Inline code like `` `wandb_init()` `` will be linked. The `@` prefix for decorators and `()` suffix for functions are automatically stripped for matching.
The linkmap files contain mappings from identifiers to their API documentation URLs. Both short names (e.g., `wandb_init`) and fully qualified names (e.g., `flyteplugins.wandb.wandb_init`) are supported if included in the linkmap.
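The stripping rule for inline code can be sketched as follows (the linkmap URL below is hypothetical):

```python
def normalize_identifier(code_text: str) -> str:
    # Strip the '@' decorator prefix and '()' call suffix before linkmap lookup
    name = code_text.strip()
    if name.startswith("@"):
        name = name[1:]
    if name.endswith("()"):
        name = name[:-2]
    return name

linkmap = {"wandb_init": "/api/flyteplugins.wandb.wandb_init"}  # hypothetical URL
print(linkmap.get(normalize_identifier("@wandb_init()")))
```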
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/llm-docs ===
# LLM-optimized documentation
The build pipeline generates LLM-optimized versions of every page and several index files,
designed for use by AI coding agents and AI search engines.
## Output files
The `make dist` command (specifically the `make llm-docs` step) produces the following
in each variant's `dist/docs/v2/{variant}/` directory:
| File | Description |
|------|-------------|
| `page.md` | Per-page LLM-optimized Markdown, generated alongside every `index.html`. Links between pages use relative `page.md` references, then are converted to absolute URLs in a final pass. |
| `section.md` | A single-file bundle concatenating all pages in a section. Only generated for sections with `llm_readable_bundle: true` in frontmatter. Internal links become hierarchical bold references; external links become absolute URLs. |
| `llms.txt` | Page index listing every page grouped by section, with H2/H3 headings for discoverability. Sections with bundles are marked with a "Section bundle" link. |
| `llms-full.txt` | The entire documentation for one variant as a single file, with all internal links converted to hierarchical bold references (e.g. `**Configure tasks > Resources**`). |
### Discovery hierarchy
```
dist/docs/llms.txt # Root: lists versions
dist/docs/v2/llms.txt # Version: lists variants
dist/docs/v2/{variant}/llms.txt # Variant: page index with headings
dist/docs/v2/{variant}/llms-full.txt # Full consolidated doc
dist/docs/v2/{variant}/**/page.md # Per-page Markdown
dist/docs/v2/{variant}/**/section.md # Section bundles (where enabled)
```
## How `page.md` files are generated
1. Hugo builds the site into `dist/` and also outputs a Markdown format into `tmp-md/`.
2. `process_shortcodes.py` reads from `tmp-md/`, resolves all shortcodes (variants, code includes, tabs, notes, etc.), and writes the result as `page.md` alongside each `index.html`.
3. `fix_internal_links_post_processing()` converts all internal links in `page.md` files to point to other `page.md` files using relative paths.
4. `build_llm_docs.py` then enhances subpage listings with H2/H3 headings, generates section bundles, converts all relative links to absolute URLs, and creates the `llms.txt` and `llms-full.txt` index files.
## Enabling section bundles
To produce a `section.md` bundle for a documentation section:
1. Add `llm_readable_bundle: true` to the frontmatter of the section's `_index.md`:
```yaml
---
title: Configure tasks
weight: 8
variants: +flyte +byoc +selfmanaged
llm_readable_bundle: true
---
```
2. Add the `{{* llm-bundle-note */>}}` shortcode in the body of the same `_index.md`,
right after the page title:
```markdown
# Configure tasks
{{* llm-bundle-note */>}}
As we saw in ...
```
This renders a note on the HTML page pointing readers to the `section.md` file.
Both the frontmatter parameter and the shortcode are required.
A CI check (`check-llm-bundle-notes`) verifies they are always in sync.
## The `llms-full.txt` link conversion
In `llms-full.txt`, all internal `page.md` links are converted to hierarchical bold references:
* Cross-page: `[Resources](../resources/page.md)` becomes `**Configure tasks > Resources**`
* Same-page anchor: `[Image building](#image-building)` becomes `**Container images > Image building**`
* External links (`http`/`https`) are preserved unchanged.
This makes the file self-contained with no broken references.
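A simplified sketch of that conversion, assuming a mapping from internal link targets to hierarchical titles is available:

```python
import re

def to_bold_refs(markdown: str, titles: dict) -> str:
    # titles maps internal page.md targets to "Section > Page" names (assumed input)
    def repl(match):
        text, target = match.group(1), match.group(2)
        if target.startswith(("http://", "https://")):
            return match.group(0)  # external links are preserved unchanged
        return "**%s**" % titles.get(target, text)
    return re.sub(r"\[([^\]]+)\]\(([^)]+)\)", repl, markdown)

titles = {"../resources/page.md": "Configure tasks > Resources"}
print(to_bold_refs("See [Resources](../resources/page.md).", titles))
# See **Configure tasks > Resources**.
```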
## Regenerating
LLM documentation is regenerated automatically as part of `make dist`.
To regenerate only the LLM files without a full rebuild:
```
make llm-docs
```
New pages are included automatically if linked via `## Subpages` in their parent's Hugo output.
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/publishing ===
# Publishing
## Requirements
1. Hugo (https://gohugo.io/)
```bash
brew install hugo
```
2. A preferences override file with your configuration
The tool is flexible and has multiple knobs. Please review `unionai-docs-infra/hugo.local.toml~sample`, and configure to meet your preferences.
```bash
cp unionai-docs-infra/hugo.local.toml~sample hugo.local.toml
```
3. Make sure you review `hugo.local.toml`.
## Managing the Tutorial Pages
The tutorials are maintained in the [unionai/unionai-examples](https://github.com/unionai/unionai-examples) repository and are imported as a git submodule in the `unionai-examples` directory.
To initialize the submodule on a fresh clone of this repository, run:
```
$ make init-examples
```
To update the submodule to the latest `main` branch, run:
```
$ make update-examples
```
## Building and running locally
```
$ make dev
```
## Developer Experience
Running `make dev` launches the site in development mode.
Changes are hot reloaded: edit in your favorite editor and the page will refresh immediately in the browser.
### Controlling Development Environment
You can change how the development environment works by setting values in `hugo.local.toml`. The following settings are available:
* `variant` - The current variant to display. Change this in 'hugo.local.toml', save, and the browser will refresh automatically
with the new variant.
* `show_inactive` - If 'true', it will show all the content that did not match the variant.
This is useful when the page contains multiple sections that vary with the selected variant,
so you can see all at once.
* `highlight_active` - If 'true', it will also highlight the *current* content for the variant.
* `highlight_keys` - If 'true', it highlights replacement keys and their values
### Changing 'variants'
Variants are flavors of the site (that you can change at the top).
During development, you can render any variant by setting it in `hugo.local.toml`:
```
variant = "byoc"
```
We call this the "active" variant.
You can also render variant content from other variants at the same time as well as highlighting the content of your active variant:
To show the content from variants other than the currently active one set:
```
show_inactive = true
```
To highlight the content of the currently active variant (to distinguish it from common content that applies to all variants), set:
```
highlight_active = true
```
> You can create your own copy of `hugo.local.toml` by copying from `unionai-docs-infra/hugo.local.toml~sample` to get started.
## Troubleshooting
### Identifying Problems: Missing Content
Content may be hidden due to `{{* variant */>}}` blocks. To see what's missing,
you can adjust the variant show/hide in development mode.
For a production-like look set:
```
show_inactive = false
highlight_active = false
```
For a full-developer experience, set:
```
show_inactive = true
highlight_active = true
```
### Identifying Problems: Page Visibility
The developer site will show you in red any pages missing from the variant.
Every page must be explicitly included in or excluded from each variant via the `variants:` field at the top of the file.
Clicking on the red page will give you the path you must add to the appropriate variant in the YAML file and a link with guidance.
Please refer to [Authoring](./authoring) for more details.
## Building Production
```
$ make dist
```
This will build all the variants and place the result in the `dist` folder.
### Testing Production Build
You can run a local web server and serve the `dist/` folder. The site should behave exactly as it would at its official URL.
To start a server:
```
$ make serve [PORT=<port>]
```
If PORT is not specified, it defaults to 9000.
Example:
```
$ make serve PORT=4444
```
Then open your browser at `http://localhost:<port>` to see the content. In the example above, that would be `http://localhost:4444/`.
=== PAGE: https://www.union.ai/docs/v2/flyte/release-notes ===
# Release notes
## March 2026
### :wrench: Extended Idle Timeout for Panel Apps
Panel apps now support longer idle times for websocket connections, with session token expiration increased to 3 hours. New parameters for managing unused session lifetimes improve stability of long-running applications.
### :wrench: Plugin Variants Documentation
The new `--plugin-variants` flag in `flyte gen docs` generates variant-scoped CLI documentation. Plugin-contributed CLI commands are wrapped in Hugo `{{* variant */>}}` shortcodes, so core commands appear unconditionally while plugin commands are shown only in specified variants (e.g., `byoc`, `selfmanaged`).
### :rocket: Google Gemini Plugin Integration
You can now integrate Google's Gemini API with Flyte using the new `function_tool` decorator to automatically convert Flyte tasks into Gemini agent tools. Both synchronous and asynchronous operations are supported.
```python
import flyte
from flyteplugins.gemini import function_tool, run_agent

env = flyte.TaskEnvironment("gemini-agent")

@env.task
async def get_weather(city: str) -> str:
    return f"The weather in {city} is sunny."

# Run Gemini agent with a tool
async def agent_task(prompt: str):
    tools = [function_tool(get_weather)]
    return await run_agent(prompt=prompt, tools=tools, model="gemini-2.5-flash")
```
### :hammer: Forced Image Build Caching
You can now force a rebuild of images by setting `force=True`, which skips the existence check and rebuilds even if the image already exists. When using the remote image builder, this also sets `overwrite_cache=True`.
```python
import flyte
image = flyte.Image("your_image")
result = await flyte.build.aio(image, force=True)
```
### :computer: LLM-Powered Code Generation
The new `flyteplugins-codegen` plugin generates code from natural language prompts, runs tests, and iterates in isolated sandboxes using LLMs.
```python
import flyte
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(model="gpt-4.1", name="data-processor", resources=flyte.Resources(cpu=1, memory="1Gi"))
result = await agent.generate.aio(
    prompt="Process the CSV data to calculate total revenue and units.",
    samples={"sales": csv_file},
    outputs={"total_revenue": float, "total_units": int},
)
```
### :wrench: Updated AI Plugin Examples
Fixed and improved plugin examples for working with OpenAI and Anthropic in Flyte 2.0, using updated versions of `flyteplugins-openai` and `flyteplugins-anthropic`.
```python
import asyncio

import flyte
from flyteplugins.openai.agents import function_tool

agent_env = flyte.TaskEnvironment(
    "openai-agent",
    resources=flyte.Resources(cpu=1),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
)

@function_tool
@agent_env.task
async def get_bread() -> str:
    await asyncio.sleep(1)
    return "bread"
```
### :wrench: Debug Mode Integration
The Flyte SDK now supports a `--debug` flag to initiate tasks in VS Code debug mode from the CLI or Python interface. Specify `debug=True` in `flyte.with_runcontext` to attach a VS Code debugger during task execution.
```python
import flyte

env = flyte.TaskEnvironment(name="debug_example")

@env.task
def say_hello(name: str) -> str:
    greeting = f"Hello, {name}!"
    print(greeting)
    return greeting

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.with_runcontext(debug=True).run(say_hello, name="World")
    print(run.name)
    print("Run url", run.url)
    print("Waiting for debug url...")
    print("Debug url", run.get_debug_url())
```
### :sparkles: Improved CLI Enum Support
The Flyte CLI now supports `EnumParamType`, allowing you to pass enum names directly (e.g., `--color=GREEN`) instead of requiring internal values.
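The name-or-value lookup the CLI performs can be sketched in plain Python (an illustration of the behavior, not the CLI's actual implementation):

```python
import enum

class Color(enum.Enum):
    RED = "red"
    GREEN = "green"
    BLUE = "blue"

def parse_enum(cls, text):
    """Accept either the member name (e.g. GREEN) or its value (e.g. green)."""
    if text in cls.__members__:  # lookup by name first
        return cls[text]
    return cls(text)             # fall back to lookup by value

print(parse_enum(Color, "GREEN"))  # Color.GREEN
print(parse_enum(Color, "green"))  # Color.GREEN
```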
### :memo: Programmatic Log Access
You can now access logs programmatically using the `get_logs()` method on `remote.Run` and `remote.Action`. This returns an iterator over log lines with support for asynchronous processing via `.aio()`, filtering system-generated logs, and including timestamps.
### :zap: Simplified PyTorch Example Setup
PyTorch environment setup is simplified: specify `flyteplugins-pytorch` directly via `with_pip_packages` instead of the internal `PythonWheels` API.
### :chart_with_upwards_trend: Distributed Training Evaluation
Flyte now supports distributed training with callback-driven evaluation. `EvalOnCheckpointCallback` automatically triggers evaluation tasks after each training checkpoint, running evaluations in parallel with training and monitoring convergence. Upon convergence, a stop signal gracefully halts training.
### :zap: Improved Benchmark Flexibility
The benchmark script for large I/O operations has been refactored. CPU and memory allocations are now parameterizable, file and directory tests can be run independently, and HTML report generation handles missing data gracefully.
### :computer: CLI Project Management
You can now create, update, and manage Flyte projects directly from the CLI, including setting IDs, names, descriptions, labels, and archive status.
```bash
# Example usage
flyte create project --id my_project_id --name "My Project" --description "Project description" -l team=dev -l env=prod
flyte update project my_project_id --archive
flyte get project --archived
```
### :robot: Anthropic Claude Integration
You can now integrate Flyte tasks as tools for Anthropic Claude agents. Define tasks in Flyte and convert them into Claude tool definitions using the `function_tool` utility.
### :hourglass_flowing_sand: Panel App Enhancements
The Flyte SDK panel app now uses a threaded asynchronous execution model, so actions like code execution no longer block the interface. Reo.Dev tracking integration provides monitoring capabilities.
### :gear: AWS Config File Support
Flyte now supports S3 authentication via the `AWS_CONFIG_FILE` environment variable. When both `AWS_PROFILE` and `AWS_CONFIG_FILE` are set, Flyte uses a boto3-backed credential provider for profile-based authentication.
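A typical profile-based setup looks like the following (the path and profile name are illustrative):

```shell
# Point Flyte at an ini-style AWS config file and select a profile.
# With both variables set, Flyte uses the boto3-backed credential provider.
export AWS_CONFIG_FILE="$HOME/.aws/config"
export AWS_PROFILE=my-profile
```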
### :sparkles: Improved Task Execution Reliability
Flyte now automatically uses `task.aio()` for both synchronous and asynchronous tasks, ensuring consistent execution through the Flyte controller. The previous fallback to `asyncio.to_thread()` for synchronous tasks has been removed.
### :wrench: Enhanced Action Service Integration
You can now attach custom gRPC headers when interacting with the Actions service, enabling consistent request metadata for routing and integration in distributed environments.
### :rocket: Async Training with Early Stopping
A new ML pattern example runs asynchronous training with periodic evaluations using Flyte's durable task management. The training task saves checkpoints asynchronously while evaluation tasks assess convergence, gracefully stopping training when convergence is detected.
```python
async def train(checkpoint_dir: str, total_epochs: int, seconds_per_epoch: float) -> File:
    # Training logic
    pass

async def evaluate(checkpoint_file: File, eval_round: int, convergence_loss: float) -> bool:
    # Evaluation logic
    pass

async def main(total_epochs, seconds_per_epoch, convergence_loss, eval_interval_seconds, max_eval_rounds):
    # Orchestration logic
    pass
```
Use `flyte run examples/ml/async_train_eval.py` to execute this pattern locally.
### :wrench: Improved Include Path Handling
Flyte now correctly resolves include paths relative to the app directory during deployment. Previously, include paths that escaped the app script's directory caused deployment failures due to invalid tar entries.
### :zap: Enhanced Retry Management
Task retries during local runs now support exponential backoff and detailed tracking of retry attempts, allowing recovery from transient errors. Retry visibility is improved in both the controller logic and the terminal UI.
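Exponential backoff itself is a simple pattern; a minimal stdlib-only sketch (not Flyte's actual controller code):

```python
import random
import time

def retry_with_backoff(fn, max_attempts=4, base_delay=0.1, jitter=0.05):
    """Retry fn on exception, doubling the delay after each failed attempt."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the original error
            delay = base_delay * (2 ** attempt) + random.uniform(0, jitter)
            time.sleep(delay)

# A function that fails twice with a transient error, then succeeds:
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"

print(retry_with_backoff(flaky))  # ok
```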
### :zap: Improved Module Loading
The Flyte SDK's module loading now respects `.gitignore` and standard ignore rules, excluding directories like `.venv` and `__pycache__`.
### :zap: Dynamic Batching for Improved GPU Utilization
New `DynamicBatcher` and `TokenBatcher` classes allow concurrent tasks to share a single GPU, improving throughput for use cases like large-scale inference. An example demonstrates `TokenBatcher` for inference tasks with reusable containers.
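The core idea, collecting concurrent requests into one batch before making a single accelerator call, can be sketched with `asyncio` (a toy illustration of the pattern, not the `DynamicBatcher` API):

```python
import asyncio

class MicroBatcher:
    """Collect concurrent requests for a short window, then run them as one batch."""

    def __init__(self, window=0.01):
        self.window = window
        self.pending = []   # list of (input, future) pairs
        self.flusher = None

    async def submit(self, x):
        fut = asyncio.get_running_loop().create_future()
        self.pending.append((x, fut))
        if self.flusher is None:
            self.flusher = asyncio.create_task(self._flush())
        return await fut

    async def _flush(self):
        await asyncio.sleep(self.window)  # let concurrent callers queue up
        batch, self.pending, self.flusher = self.pending, [], None
        results = [x * 2 for x, _ in batch]  # one "GPU call" for the whole batch
        for (_, fut), r in zip(batch, results):
            fut.set_result(r)

async def demo():
    batcher = MicroBatcher()
    return await asyncio.gather(*(batcher.submit(i) for i in range(4)))

print(asyncio.run(demo()))  # [0, 2, 4, 6]
```

All four submissions above are served by a single batched computation rather than four separate ones.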
### :sparkles: Run Cache Disabling
You can now disable run-level task result caching. When caching is disabled for a specific run, no cache hits are reported and cache operations are bypassed. The TUI reflects this with a clear indication that caching is disabled.
### :computer: Vim Key Navigation for TUI
The TUI (`FlyteTUIApp` and `ExploreTUIApp`) now supports Vim keys `j` and `k` for cursor movement in the `ActionTreeWidget` and `RunsTable`.
### :sparkles: Clickable Image Build URLs
Image URIs in TaskMetadata are now clickable in the Union frontend, linking directly to the Flyte run that built the image.
### :sparkles: Enhanced Run Filters
You can now filter runs and actions by project, domain, and creation/update time ranges. The new `TimeFilter` class supports filtering by `created_at` and `updated_at` timestamps, and filters are available through both the SDK and the CLI.
```python
from flyte.remote import Run, TimeFilter

# Example usage to fetch runs created after a specific date
runs = Run.listall(
    project="my-project",
    created_at=TimeFilter(after="2026-03-01"),
)
```
### :wrench: Simplified Dependency Management
`UVProject`'s `dependencies_only` mode now copies only the `pyproject.toml` files of each editable dependency instead of the entire directory, reducing build context size and speeding up image builds.
### :robot: MLE Agent Enhancements
Two new agents, the MLE Orchestrator Agent and the MLE Tool Builder Agent, use LLMs to automatically generate orchestration and processing code. They create, execute, and iteratively optimize ML models in an isolated sandbox environment with configurable computing resources.
### :sparkles: Improved Task Command Initialization
The Flyte CLI now initializes configuration when listing or resolving task commands via `TaskPerFileGroup`, preventing failures for config-dependent operations.
```python
import flyte
from flyte.io import File

env = flyte.TaskEnvironment(name="example_env")

@env.task
async def test_file(project: str, input_file: File) -> str:
    return f"Got input {project=}, {input_file=}"
```
### :zap: New Example Applications & Bug Fixes
New example applications added:
- Distributed training using async tasks
- MNIST model handling with PyTorch
- Agent workflows with LangGraph & Gemini API
Also includes a bug fix for scaling metric serialization.
### :gear: Phase Transitions Tracking
You can now view phase transition details for actions, showing time spent in each phase (QUEUED, INITIALIZING, RUNNING, etc.). Use the `get_phase_transitions` method and properties like `queued_time` and `running_time` to identify bottlenecks programmatically.
```python
from flyte.remote import Action

action = Action.get(run_name="my-run", name="my-action")
details = action.details()
transitions = details.get_phase_transitions()
for t in transitions:
    print(f"{t.phase}: {t.duration.total_seconds()}s")
```
### :wrench: Multiple Source Files Support
`with_source_file` now accepts a list of file paths, allowing multiple files in a single image layer. An error is raised if duplicate filenames target the same location.
```python
from flyte._image import Image
from pathlib import Path
# Example usage with two different files
img = Image.from_debian_base(name="my-image").with_source_file([Path("a.py"), Path("b.py")])
```
### :package: Simplified Code Bundling
The new `with_code_bundle()` method packages source code into Docker images. When `copy_style` is set to `"none"` in `with_runcontext()` or during `flyte deploy`, source code is automatically baked into the image. Use `"loaded_modules"` to include specific Python modules or `"all"` for entire directories.
### :wrench: Improved Error Messaging for Deployment
When using a `src/` layout, the "Duplicate environment name" error during deployment now hints at the `--root-dir` option to help resolve dual-import issues.
```bash
# Use --root-dir to resolve dual-import issues with a src/ layout
flyte deploy --dry-run --recursive --root-dir src src/my_module
```
### :wrench: Improved Debugging for Reusable Tasks
Reusable tasks now automatically disable debugging. Previously, debugging was enabled by default, which could cause issues with reusable tasks.
### :sparkles: JSONL Plugin Support
The new JSONL plugin adds `JsonlFile` and `JsonlDir` types for Flyte workflows. It supports async and sync read/write operations with optional `zstd` compression, using `orjson` for fast serialization.
```python
from flyteplugins.jsonl import JsonlFile, JsonlDir
# Example usage of JsonlFile
@env.task
async def process_file(f: JsonlFile):
    async for record in f.iter_records():
        print(record)

# Example usage of JsonlDir for sharded directories
@env.task
async def process_dir(d: JsonlDir):
    async for record in d.iter_records():
        print(record)
```
## February 2026
### :sparkles: JSON Schema Enhancement
Flyte now accurately converts Python types to JSON Schemas by leveraging Flyte's internal type system. Previously, certain types like `Literal["C", "F"]` were incorrectly mapped. Now, input schemas for Flyte tasks reflect precise JSON Schemas, improving integrations with tools like Anthropic's Claude.
```python
from typing import Literal

# Example: Converting Literal to JSON Schema correctly
def my_func(unit: Literal["C", "F"]) -> str:
    return unit

schema = NativeInterface.from_callable(my_func).json_schema
assert schema["properties"]["unit"] == {"type": "string", "enum": ["C", "F"]}
```
### :calculator: Panel Calculator Example
A new example showcases a calculator app embedded in a Panel interface using Flyte's `AppEnvironment`, demonstrating how to build interactive web-based UIs with Flyte.
### :sparkles: Spark Plugin Update
The `flyteplugins-spark` dependency has been updated to `>=2.0.0`, moving away from pre-release versions.
### :lock: Secure Package Specification
Package version constraints like `apache-airflow<=3.0.0` are now automatically quoted in generated Dockerfiles. Previously, unquoted constraints could cause incorrect shell interpretation and build failures.
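The hazard is easy to see: in a shell, an unquoted `<` is parsed as a redirection operator. Quoting the specifier, as the stdlib's `shlex.quote` does, makes the constraint safe to embed in a generated command:

```python
import shlex

spec = "apache-airflow<=3.0.0"
# Unquoted, the shell would parse '<' as a redirection operator.
quoted = shlex.quote(spec)
print(f"pip install {quoted}")  # pip install 'apache-airflow<=3.0.0'
```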
### :zap: Enum Name Acceptance in CLI
The Flyte CLI now accepts enum names as valid inputs. Previously, only enum values were accepted, so `--color=RED` would fail when the value was `"red"`. Both names and values are now accepted.
```python
import enum

import flyte

class Color(enum.Enum):
    RED = "red"
    GREEN = "green"
    BLUE = "blue"

@flyte.task
def example_task(color: Color):
    return f"Selected color is {color.name}"
```
### :wrench: Enhanced Pod Template Handling
Pod templates are now properly maintained across task overrides. Previously, overriding certain task attributes could inadvertently discard custom pod templates. Pod specifications, labels, and annotations now persist even after renaming tasks or modifying other properties.
### :zap: Stress Testing Example Added
A new stress testing example demonstrates a fan-out execution pattern, creating a dynamic tree of asynchronous tasks to simulate high concurrency. You can control the number of tasks spawned at each layer and introduce variability with a jitter parameter.
### :bug: Correct Serialization Field
Fixed a bug in the serialization of scaling metrics: the correct field `target_value` is now used instead of `val`. This ensures proper serialization for `Scaling.Concurrency` and `Scaling.RequestRate` metrics as expected by the protobuf definitions.
### :wrench: Improved Async Task Handling
Async Flyte tasks now route execution through `task.aio()`, ensuring consistent invocation through Flyte's controller and correct handling of nested async tasks.
### :wrench: Sync Alignment of File Upload Methods
`File.from_local_sync` and `File.from_local` now handle filenames consistently when uploading to remote storage. Previously, the sync and async methods could produce different filenames for the same upload.
```python
# Example of uploading a file with consistent naming:
import os
import tempfile

from flyte.io import File

with tempfile.TemporaryDirectory() as temp_dir:
    local_path = os.path.join(temp_dir, "source.txt")
    remote_path = os.path.join(temp_dir, "destination.txt")
    # Write the file content
    with open(local_path, "w") as f:
        f.write("sample content")
    # Upload the local file to a remote location
    uploaded_file = File.from_local_sync(local_path, remote_path)
    print(f"Uploaded file path: {uploaded_file.path}")
```
### :hourglass: Request Timeout Configuration
You can now configure request timeouts for Flyte applications using the new `Timeouts` dataclass. Set a `request` timeout (as an integer or `timedelta`) to limit the maximum duration a request can take within an application environment.
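The shape of such a setting can be sketched with a plain dataclass (an illustration of the concept only; everything beyond the `request` field named in the release note is an assumption):

```python
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional, Union

@dataclass
class Timeouts:
    # Maximum duration a request may take: seconds as an int, or a timedelta.
    request: Union[int, timedelta, None] = None

    def request_seconds(self) -> Optional[float]:
        """Normalize the configured timeout to seconds."""
        if self.request is None:
            return None
        if isinstance(self.request, timedelta):
            return self.request.total_seconds()
        return float(self.request)

print(Timeouts(request=timedelta(minutes=2)).request_seconds())  # 120.0
print(Timeouts(request=30).request_seconds())  # 30.0
```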
### :wrench: Enhanced Bundling and Error Handling
Flyte now ignores `.git` directories in deployment code bundles, reducing artifact size and improving deployment speed. Additionally, explicit error handling for the `copy_style` parameter provides clear guidance when bundling is unnecessary.
### :wrench: Dynamic Pydantic Model Creation
The new `PydanticTransformer.guess_python_type` method dynamically creates Pydantic models from JSON schema metadata. This handles cases where the original Pydantic model class isn't available, enabling flexible deserialization of complex nested structures.
### :busts_in_silhouette: Human-in-the-Loop Plugin
The new Human-in-the-Loop (HITL) plugin enables workflows to pause and wait for human input via a web interface or programmatically. Create events that prompt for human interaction through an auto-served FastAPI app.
```python
import flyteplugins.hitl as hitl
# Create event and wait for human input
event = await hitl.new_event.aio(
    "input_event",
    data_type=int,
    scope="run",
    prompt="Enter a number",
)
value = await event.wait.aio()
```
### :rocket: Stateless Code Sandbox
Flyte now supports running arbitrary Python code and shell commands in an isolated, stateless Docker container with the `flyte.sandbox.create()` API. Three execution modes are available: Auto-IO, Verbatim, and Command, each handling inputs and outputs differently while running code in fresh, ephemeral containers.
### :wrench: Improved CLI Logging Initialization
The Flyte SDK now ensures a consistent logging setup when using the CLI. Previously, CLI commands would initialize configuration multiple times, leading to duplicated log entries. Now:
- Initialization occurs once per command execution.
- `RichHandler` is enabled from the start, so all logs display in rich format.
- The `hello.py` example script now has a default value, so it runs without arguments.
```python
@env.task
def main(x_list: list[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) -> float:
    x_len = len(x_list)
    if x_len < 10:
        raise ValueError(f"x_list doesn't have a large enough sample size, found: {x_len}")
    y_list = list(flyte.map(fn, x_list))
    y_mean = sum(y_list) / len(y_list)
    return y_mean
```
### :wrench: Enhanced Ignore Handling
Flyte SDK now skips processing of `.gitignore` and `.flyteignore` files inside commonly ignored directories such as `.venv` or `__pycache__`, avoiding redundant file processing.
### :whale: CI Image Builder
A new example script automates Docker image building and pushing from CI. Configure it with your source and target image details to integrate with continuous deployment pipelines.
### :wrench: TypedDict Compatibility Fix
The Flyte SDK now correctly handles `TypedDict` for Python versions earlier than 3.12 by using `typing_extensions.TypedDict`.
```python
# Importing TypedDict based on Python version
import sys

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict
```
### :globe_with_meridians: Cross-Platform Code Bundling
The Flyte SDK now uses POSIX-style paths for file hashing and tarball creation, ensuring consistent code bundling behavior across Windows and Unix systems.
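The normalization can be illustrated with the stdlib's `pathlib`:

```python
from pathlib import PureWindowsPath

# A Windows-style path normalized to the POSIX form used for hashing and
# tar entries, so the same source tree produces identical bundle contents
# on both platforms.
win_path = PureWindowsPath(r"src\pkg\mod.py")
print(win_path.as_posix())  # src/pkg/mod.py
```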
### :wrench: Improved CLI JSON Formatting
The `flyte` CLI now uses the `to_dict()` method when available for JSON output, fixing `TypeError` failures that occurred with certain non-iterable object types.
### :wrench: Improved Pod Image Handling
Flyte now consistently merges container images when using a pod template. The primary container uses `app_env.image` if no explicit image is set, with correct handling of both `"auto"` and specific image values.
### :sparkles: Flyte Webhook Environment
A pre-built Flyte webhook environment makes it easier to integrate with FastAPI endpoints for common Flyte operations like running tasks, managing apps, and handling triggers. This update uses `httpx` for HTTP requests and expands endpoint exports for better customization.
### :repeat: Retry Interceptor for gRPC
A new retry interceptor for gRPC channels allows you to define how many times a gRPC call should be retried on transient failures. Specify the number of retry attempts using the `rpc_retries` option during channel creation.
### :sparkles: Orchestration Sandbox Feature
Flyte 2.0 now supports dynamic orchestration within a sandbox using `flyte.sandbox.orchestrator_from_str()`. Create reusable orchestration templates directly from Python code strings without defining decorated functions β useful when code is dynamically generated from UIs or language models.
### :wrench: Task Shortname Override Fix
You can now override the shortname for tasks in the UI by setting the `short_name` parameter in task overrides. Previously, overridden shortnames were not reflected in the Flyte UI.
### :sparkles: NVIDIA H100 GPU Support
Flyte now supports NVIDIA H100 GPUs with various MIG partitions for fine-grained resource allocation.
```python
import flyte
from flyte import GPU, Resources

h100_mig_env = flyte.TaskEnvironment(
    name="h100_mig",
    resources=Resources(
        cpu="1",
        memory="4Gi",
        gpu=GPU(device="H100", quantity=1, partition="1g.10gb"),
    ),
)
```
### :zap: Enhanced Error Handling in PyTorch Elastic Jobs
Flyte's PyTorch integration now includes configurable NCCL timeout settings to better manage CUDA out-of-memory (OOM) situations. This prevents elastic jobs from hanging due to OOM by introducing faster failure detection and customizable restart policies. You can reduce timeout durations, enable asynchronous error handling, and activate built-in monitoring.
### :wrench: Reverse Path Priority Fix
The Flyte SDK's handling of `sys.path` when running tasks remotely now respects local path priority. Previously, the `entrypoint` directory could override top-level packages. This fix ensures consistent path prioritization between local development and remote execution.
### :globe_with_meridians: S3 Virtual Hosted-Style Support
You can now specify the addressing style for S3-compatible backends by setting the `FLYTE_AWS_S3_ADDRESSING_STYLE` environment variable to `virtual`. This constructs URLs in the format `https://<bucket>.<endpoint>/<key>`, enabling compatibility with more storage providers.
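The two addressing styles differ only in where the bucket name appears; a small sketch of the URL construction (endpoint and bucket are placeholders):

```python
def s3_url(bucket, key, endpoint, style="path"):
    """Build an S3 object URL in path-style or virtual hosted-style form."""
    if style == "virtual":
        return f"https://{bucket}.{endpoint}/{key}"
    return f"https://{endpoint}/{bucket}/{key}"

print(s3_url("my-bucket", "data.csv", "s3.example.com"))
# https://s3.example.com/my-bucket/data.csv
print(s3_url("my-bucket", "data.csv", "s3.example.com", style="virtual"))
# https://my-bucket.s3.example.com/data.csv
```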
## November 2025
### :fast_forward: Grouped Runs
We redesigned the Runs page to better support large numbers of runs. Historically, large projects produced so many runs that flat listings became difficult to navigate. The new design groups runs by their root task, leveraging the fact that while there may be millions of runs, there are typically only dozens or hundreds of deployed tasks. This grouped view, combined with enhanced filtering (by status, owner, duration, and more coming soon), makes it dramatically faster and easier to locate the exact runs you are looking for, even in the largest deployments.

### :globe_with_meridians: Apps (beta)
You can now deploy apps in Union 2.0. Apps let you host ML models, Streamlit dashboards, FastAPI services, and other interactive applications alongside your workflows. Simply define your app, deploy it, and Union will handle the infrastructure, routing, and lifecycle management. You can even call apps from your tasks to build end-to-end workflows that combine batch processing with real-time serving.
To create an app, import `flyte` and use either `FastAPIAppEnvironment` for FastAPI applications or the generic `AppEnvironment` for other frameworks. Here's a simple FastAPI example:
```python
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI()

env = FastAPIAppEnvironment(
    name="my-api",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("fastapi", "uvicorn"),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)

@env.app.get("/greeting/{name}")
async def greeting(name: str) -> str:
    return f"Hello, {name}!"

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env)  # Deploy and serve your app
```
For Streamlit apps, use the generic `AppEnvironment` with a command:
```python
app_env = flyte.app.AppEnvironment(
    name="streamlit-hello-v2",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1"),
    command="streamlit hello --server.port 8080",
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)
```
You can call apps from tasks by using `depends_on` and making HTTP requests to the app's endpoint. Please refer to the example in the [SDK repo](https://github.com/flyteorg/flyte-sdk/blob/main/examples/apps/call_apps_in_tasks/app.py). Similarly, you can call apps from other apps (see this [example](https://github.com/flyteorg/flyte-sdk/blob/main/examples/apps/app_calling_app/app.py)).
### :label: Custom context
You can now pass configuration and metadata implicitly through your entire task execution hierarchy using custom context. This is ideal for cross-cutting concerns like tracing IDs, experiment metadata, environment information, or logging correlation keys: data that needs to be available everywhere but isn't logically part of your task's computation.
Custom context is a string key-value map that automatically flows from parent to child tasks without adding parameters to every function signature. Set it once at the run level with `with_runcontext()`, or override values within tasks using the `flyte.custom_context()` context manager:
```python
import flyte

env = flyte.TaskEnvironment("custom-context-example")

@env.task
async def leaf_task() -> str:
    # Reads run-level context
    print("leaf sees:", flyte.ctx().custom_context)
    return flyte.ctx().custom_context.get("trace_id")

@env.task
async def root() -> str:
    return await leaf_task()

if __name__ == "__main__":
    flyte.init_from_config()
    # Base context for the entire run
    run = flyte.with_runcontext(custom_context={"trace_id": "root-abc", "experiment": "v1"}).run(root)
    print(run.url)
```
### :lock: Secrets UI
Now you can view and create secrets directly from the UI. Secrets are stored securely in your configured secrets manager and injected into your task environments at runtime.

### Image builds now run in the same project-domain
The image build task is now executed within the same project and domain as the user task, rather than in system-production. This change improves isolation and is a key step toward supporting multi-dataplane clusters.
### Support for secret mounts in Poetry and UV projects
We added support for mounting secrets into both Poetry and UV-based projects. This enables secure access to private dependencies or credentials during image build.
```python
import pathlib

import flyte

env = flyte.TaskEnvironment(
    name="uv_project_lib",
    resources=flyte.Resources(memory="1000Mi"),
    image=(
        flyte.Image.from_debian_base().with_uv_project(
            pyproject_file=pathlib.Path(__file__).parent / "pyproject.toml",
            pre=True,
            secret_mounts="my_secret",
        )
    ),
)
```
## October 2025
### :infinity: Larger fanouts
You can now run up to 50,000 actions within a run and up to 1,000 actions concurrently.
To enable observability across so many actions, we added group and sub-actions UI views, which show summary statistics about the actions which were spawned within a group or action.
You can use these summary views (as well as the action status filter) to spot check long-running or failed actions.

### :computer: Remote debugging for Ray head nodes
Rather than locally reproducing errors, sometimes you just want to zoom into the remote execution and see what's happening.
We directly enable this with the debug button.
When you click "Debug action" from an action in a run, we spin up that action's environment, code, and input data, and attach a live VS Code debugger.
Previously, this was only possible with vanilla Python tasks.
Now, you can debug multi-node distributed computations on Ray directly.

### :zap: Triggers and audit history
**Configure tasks > Triggers** let you templatize and set schedules for your workflows, similar to Launch Plans in Flyte 1.0.
```python
@env.task(triggers=flyte.Trigger.hourly())  # Every hour
def example_task(trigger_time: datetime, x: int = 1) -> str:
    return f"Task executed at {trigger_time.isoformat()} with x={x}"
```
Once you deploy, it's possible to see all the triggers which are associated with a task:

We also maintain an audit history of every deploy, activation, and deactivation event, so you can get a sense of who's touched an automation.

### :arrow_up: Deployed tasks and input passing
You can see the runs, task spec, and triggers associated with any deployed task, and launch it from the UI. We've converted the launch forms to a convenient JSON Schema syntax, so you can easily copy-paste the inputs from a previous run into a new run for any task.

=== PAGE: https://www.union.ai/docs/v2/flyte/deployment ===
# Platform deployment
The Union.ai platform uses a split-plane model with separate control and data planes.
In both BYOC and Self-managed deployments, your code, input and output data, container images and logs reside entirely on the **data plane**, which runs in your cloud account, while the **control plane** runs on Union.ai's cloud account, providing the workflow orchestration logic.
The **control plane** does not have access to the code, data, images, or logs in the **data plane**.
If you choose a **Self-managed deployment**, your data isolation is further enhanced by the fact that you manage your data plane entirely on your own, without providing any access to Union.ai customer support.
If you choose a **BYOC deployment**, Union.ai manages the Kubernetes cluster in your data plane for you. The data isolation of the control vs. data plane is still enforced - for example, Union.ai has no access to your object storage or logs. However, Union.ai customer support will have some access to your cluster, though strictly for upgrades, provisioning, and other actions related to maintaining cluster health.
## Subpages