=== PAGE: https://www.union.ai/docs/v2/flyte ===
# Documentation
Welcome to the documentation.
## Subpages
- **Flyte**
- **Tutorials**
- **Integrations**
- **Reference**
- **Community**
- **Release notes**
- **Platform deployment**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide ===
# Flyte
Flyte is a free and open source platform that provides a full suite of powerful features for orchestrating AI workflows.
Flyte empowers AI development teams to rapidly ship high-quality code to production by offering optimized performance, unparalleled resource efficiency, and a delightful workflow authoring experience.
You deploy and manage Flyte yourself, on your own cloud infrastructure.
> [!NOTE]
> These are the Flyte **2.0 beta** docs.
> To switch to [version 1.0](/docs/v1/flyte/) or to the commercial product, [**Union.ai**](/docs/v2/byoc/), use the selectors above.
>
> This documentation for open-source Flyte is maintained by Union.ai.
### **Flyte 2**
Flyte 2 represents a fundamental shift in how AI workflows are written and executed. Learn
more in this section.
### **Getting started**
Install Flyte 2, configure your local IDE, create and run your first task, and inspect the results in 2 minutes.
## Subpages
- **Flyte 2**
- **Getting started**
- **Configure tasks**
- **Build tasks**
- **Run and deploy tasks**
- **Scale your runs**
- **Configure apps**
- **Build apps**
- **Serve and deploy apps**
- **Considerations**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2 ===
# Flyte 2
Flyte 2 represents a fundamental shift in how workflows are written and executed in Flyte.
> **Ready to get started?**
>
> Go to the **Getting started** guide to install Flyte 2 and run your first task.
## Pure Python execution
Write workflows in pure Python, enabling a more natural development experience and removing the constraints of a
domain-specific language (DSL).
### Sync Python
```
import flyte

env = flyte.TaskEnvironment("sync_example_env")

@env.task
def hello_world(name: str) -> str:
    return f"Hello, {name}!"

@env.task
def main(name: str) -> str:
    for i in range(10):
        hello_world(name)
    return "Done"

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/sync_example.py)
### Async Python
```
import asyncio

import flyte

env = flyte.TaskEnvironment("async_example_env")

@env.task
async def hello_world(name: str) -> str:
    return f"Hello, {name}!"

@env.task
async def main(name: str) -> str:
    results = []
    for i in range(10):
        results.append(hello_world(name))
    await asyncio.gather(*results)
    return "Done"

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async_example.py)
As you can see in the hello world example, workflows can be constructed at runtime, allowing for more flexible and
adaptive behavior. Flyte 2 also supports:
- Python's asynchronous programming model to express parallelism.
- Python's native error handling with `try-except`, which can be combined with overridden configurations, like resource requests (see the sketch below).
- Predefined static workflows when compile-time safety is critical.
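For instance, a task failure can be caught with `try-except` and the task re-invoked with larger resource requests via `task.override()`. Below is a minimal sketch, assuming a hypothetical `train` task; the environment name and resource figures are illustrative:
```python
import flyte

env = flyte.TaskEnvironment(name="recovery_example_env")

@env.task
async def train(size: str) -> str:
    # Hypothetical training step that may fail under its default resources.
    return f"Trained a {size} model"

@env.task
async def main() -> str:
    try:
        return await train("large")
    except Exception:
        # Re-run the same task with larger resource requests.
        return await train.override(
            resources=flyte.Resources(cpu=4, memory="8Gi"),
        )("large")
```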
## Simplified API
The new API is more intuitive, with fewer abstractions to learn and a focus on simplicity.
| Use case | Flyte 1 | Flyte 2 |
| ----------------------------- | --------------------------- | --------------------------------------- |
| Environment management | `N/A` | `TaskEnvironment` |
| Perform basic computation | `@task` | `@env.task` |
| Combine tasks into a workflow | `@workflow` | `@env.task` |
| Create dynamic workflows | `@dynamic` | `@env.task` |
| Fanout parallelism | `flytekit.map` | Python `for` loop with `asyncio.gather` |
| Conditional execution | `flytekit.conditional` | Python `if-elif-else` |
| Catching workflow failures | `@workflow(on_failure=...)` | Python `try-except` |
There is no `@workflow` decorator. Instead, "workflows" are authored through a pattern of tasks calling tasks.
Tasks are defined within environments, which encapsulate the context and resources needed for execution.
## Fine-grained reproducibility and recoverability
Flyte tasks support caching via `@env.task(cache=...)`, and tracing with `@flyte.trace` augments task-level caching
even further, enabling reproducibility and recovery at the sub-task function level.
```
import flyte

env = flyte.TaskEnvironment(name="trace_example_env")

@flyte.trace
async def call_llm(prompt: str) -> str:
    return "Initial response from LLM"

@env.task
async def finalize_output(output: str) -> str:
    return "Finalized output"

@env.task(cache=flyte.Cache(behavior="auto"))
async def main(prompt: str) -> str:
    output = await call_llm(prompt)
    output = await finalize_output(output)
    return output

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, prompt="Prompt to LLM")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/trace.py)
Here the `call_llm` function is called in the same container as `main` and serves as an automated checkpoint with full
observability in the UI. If the task run fails, the workflow can recover and replay from where it left off.
## Improved remote functionality
Flyte 2 provides full management of the workflow lifecycle through a standardized API, exposed via both the CLI and the Python SDK.
| Use case | CLI | Python SDK |
| ------------- | ------------------ | ------------------- |
| Run a task | `flyte run ...` | `flyte.run(...)` |
| Deploy a task | `flyte deploy ...` | `flyte.deploy(...)` |
You can also fetch and run remote (previously deployed) tasks within the course of a running workflow.
```
import flyte
from flyte import remote

env_1 = flyte.TaskEnvironment(name="env_1")
env_2 = flyte.TaskEnvironment(name="env_2")
env_1.add_dependency(env_2)

@env_2.task
async def remote_task(x: str) -> str:
    return "Remote task processed: " + x

@env_1.task
async def main() -> str:
    remote_task_ref = remote.Task.get("env_2.remote_task", auto_version="latest")
    r = await remote_task_ref(x="Hello")
    return "main called remote and received: " + r

if __name__ == "__main__":
    flyte.init_from_config()
    d = flyte.deploy(env_1)
    print(d[0].summary_repr())
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/remote.py)
## Native Notebook support
Author and run workflows and fetch workflow metadata (I/O and logs) directly from Jupyter notebooks.
## Enhanced UI
New UI with a streamlined and user-friendly experience for authoring and managing workflows.
This UI improves the visualization of workflow execution and monitoring, simplifying access to logs, metadata, and other important information.
## Subpages
- **Flyte 2 > Pure Python**
- **Flyte 2 > Asynchronous model**
- **Flyte 2 > Migration from Flyte 1 to Flyte 2**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2/pure-python ===
# Pure Python
Flyte 2 introduces a new way of writing workflows that is based on pure Python, removing the constraints of a domain-specific language (DSL) and enabling full use of Python's capabilities.
## From `@workflow` DSL to pure Python
| Flyte 1 | Flyte 2 |
| --- | --- |
| `@workflow`-decorated functions are constrained to a subset of Python for defining a static directed acyclic graph (DAG) of tasks. | **No more `@workflow` decorator**: Everything is a `@env.task`, so your top-level "workflow" is simply a task that calls other tasks. |
| `@task`-decorated functions could leverage the full power of Python, but only within individual container executions. | `@env.task`s can call other `@env.task`s and be used to construct workflows with dynamic structures using loops, conditionals, try/except, and any Python construct anywhere. |
| Workflows were compiled into static DAGs at registration time, with tasks as the nodes and the DSL defining the structure. | Workflows are simply tasks that call other tasks. Compile-time safety will be available in the future as `compiled_task`. |
### Flyte 1
```python
import flytekit

image = flytekit.ImageSpec(
    name="hello-world-image",
    packages=["requests"],
)

@flytekit.task(container_image=image)
def mean(data: list[float]) -> float:
    return sum(data) / len(data)

@flytekit.workflow
def main(data: list[float]) -> float:
    output = mean(data)
    # ❌ performing trivial operations in a workflow is not allowed
    # output = output / 100
    # ❌ if/else is not allowed
    # if output < 0:
    #     raise ValueError("Output cannot be negative")
    return output
```
### Flyte 2
```python
import flyte

env = flyte.TaskEnvironment(
    "hello_world",
    image=flyte.Image.from_debian_base().with_pip_packages("requests"),
)

@env.task
def mean(data: list[float]) -> float:
    return sum(data) / len(data)

@env.task
def main(data: list[float]) -> float:
    output = mean(data)
    # ✅ performing trivial operations in a workflow is allowed
    output = output / 100
    # ✅ if/else is allowed
    if output < 0:
        raise ValueError("Output cannot be negative")
    return output
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/pure-python/flyte_2.py)
These fundamental changes bring several transformative benefits:
- **Flexibility**: Harness the complete Python language for workflow definition, including all control flow constructs previously forbidden in workflows.
- **Dynamic workflows**: Create workflows that adapt to runtime conditions, handle variable data structures, and make decisions based on intermediate results.
- **Natural error handling**: Use standard Python `try`/`except` patterns throughout your workflows, making them more robust and easier to debug.
- **Intuitive composability**: Build complex workflows by naturally composing Python functions, following familiar patterns that any Python developer understands.
## Workflows can still be static when needed
> [!NOTE]
> This feature is coming soon.
The flexibility of dynamic workflows is absolutely needed for many use cases, but there are other scenarios where static workflows are beneficial. For these cases, Flyte 2 will offer compilation of the top-level task of a workflow into a static DAG.
This upcoming feature will provide:
- **Static analysis**: Enable workflow visualization and validation before execution
- **Predictable resources**: Allow precise resource planning and scheduling optimization
- **Traditional tooling**: Support existing DAG-based analysis and monitoring tools
- **Hybrid approach**: Choose between dynamic and static execution based on workflow characteristics
The static compilation system will naturally have limitations compared to fully dynamic workflows:
- **Dynamic fanouts**: Constructs that require runtime data to reify (for example, loops whose iteration count depends on intermediate results) will not be compilable.
- However, constructs whose size and scope *can* be determined at registration time, such as fixed-size loops or maps, *will* be compilable.
- **Conditional branching**: Decision trees whose size and structure depend on intermediate results will not be compilable.
- However, conditionals with fixed branch size will be compilable.
For the applications that require a predefined workflow graph, Flyte 2 will enable compilability up to the limits implicit in directed acyclic graphs.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2/async ===
# Asynchronous model
## Why we need an async model
The shift to an asynchronous model in Flyte 2 is driven by the need for more efficient and flexible workflow execution.
We believe, in particular, that with the rise of the agentic AI pattern, asynchronous programming has become an essential part of the AI/ML engineering and data science toolkit.
With Flyte 2, the entire framework is now written with async constructs, allowing for:
- Seamless overlapping of I/O and independent external operations.
- Composing multiple tasks and external tool invocations within the same Python process.
- Native support of streaming operations for data, observability and downstream invocations.
It is also a natural fit for expressing parallelism in workflows.
### Understanding concurrency vs. parallelism
Before diving into Flyte 2's approach, it's essential to understand the distinction between concurrency and parallelism:
| Concurrency | Parallelism |
| --- | --- |
| Dealing with multiple tasks at once through interleaved execution, even on a single thread. | Executing multiple tasks truly simultaneously across multiple cores or machines. |
| Performance benefits come from allowing the system to switch between tasks when one is waiting for external operations. | This is a subset of concurrency where tasks run at the same time rather than being interleaved. |
### Python's async evolution
Python's asynchronous programming capabilities have evolved significantly:
- **The GIL challenge**: Python's Global Interpreter Lock (GIL) traditionally prevented true parallelism for CPU-bound tasks, limiting threading effectiveness to I/O-bound operations.
- **Traditional solutions**:
- `multiprocessing`: Created separate processes to sidestep the GIL, effective but resource-intensive
- `threading`: Useful for I/O-bound tasks where the GIL could be released during external operations
- **The async revolution**: The `asyncio` library introduced cooperative multitasking within a single thread, using an event loop to manage multiple tasks efficiently.
### Parallelism in Flyte 1 vs Flyte 2
| | Flyte 1 | Flyte 2 |
| --- | --- | --- |
| Parallelism | The workflow DSL automatically parallelized tasks that weren't dependent on each other. The `map` operator allowed running a task multiple times in parallel with different inputs. | Leverages Python's `asyncio` as the primary mechanism for expressing parallelism, but with a crucial difference: **the Flyte orchestrator acts as the event loop**, managing task execution across distributed infrastructure. |
### Core async concepts
- **`async def`**: Declares a function as a coroutine. When called, it returns a coroutine object managed by the event loop rather than executing immediately.
- **`await`**: Pauses coroutine execution and passes control back to the event loop.
In standard Python, this enables other tasks to run while waiting for I/O operations.
In Flyte 2, it signals where tasks can be executed in parallel.
- **`asyncio.gather`**: The primary tool for concurrent execution.
In standard Python, it schedules multiple awaitable objects to run concurrently within a single event loop.
In Flyte 2, it signals to the orchestrator that these tasks can be distributed across separate compute resources.
#### A practical example
Consider this pattern for parallel data processing:
```
import asyncio

import flyte

env = flyte.TaskEnvironment("data_pipeline")

@env.task
async def process_chunk(chunk_id: int, data: str) -> str:
    # This could be any computational work - CPU or I/O bound
    await asyncio.sleep(1)  # Simulating work
    return f"Processed chunk {chunk_id}: {data}"

@env.task
async def parallel_pipeline(data_chunks: list[str]) -> list[str]:
    # Create coroutines for all chunks
    tasks = []
    for i, chunk in enumerate(data_chunks):
        tasks.append(process_chunk(i, chunk))
    # Execute all chunks in parallel
    results = await asyncio.gather(*tasks)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
In standard Python, this would provide concurrency benefits primarily for I/O-bound operations.
In Flyte 2, the orchestrator schedules each `process_chunk` task on separate Kubernetes pods or configured plugins, achieving true parallelism for any type of work.
### True parallelism for all workloads
This is where Flyte 2's approach becomes revolutionary: **async syntax is not just for I/O-bound operations**.
The `async`/`await` syntax becomes a powerful way to declare your workflow's parallel structure for any type of computation.
When Flyte's orchestrator encounters `await asyncio.gather(...)`, it understands that these tasks are independent and can be executed simultaneously across different compute resources.
This means you achieve true parallelism for:
- **CPU-bound computations**: Heavy mathematical operations, model training, data transformations
- **I/O-bound operations**: Database queries, API calls, file operations
- **Mixed workloads**: Any combination of computational and I/O tasks
The Flyte platform handles the complex orchestration while you express parallelism using intuitive `async` syntax.
## Bridging the transition: Sync support and migration tools
### Seamless synchronous task support
Recognizing that many existing codebases use synchronous functions, Flyte 2 provides seamless backward compatibility:
```
@env.task
def legacy_computation(x: int) -> int:
    # Existing synchronous function works unchanged
    return x * x + 2 * x + 1

@env.task
async def modern_workflow(numbers: list[int]) -> list[int]:
    # Call sync tasks from async context using .aio()
    tasks = []
    for num in numbers:
        tasks.append(legacy_computation.aio(num))
    results = await asyncio.gather(*tasks)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
Under the hood, Flyte automatically "asyncifies" synchronous functions, wrapping them to participate seamlessly in the async execution model.
You don't need to rewrite existing codeβjust leverage the `.aio()` method when calling sync tasks from async contexts.
### The `flyte.map` function: Familiar patterns
For scenarios that previously used Flyte 1's `map` operation, Flyte 2 provides `flyte.map` as a direct replacement.
The new `flyte.map` can be used either in synchronous or asynchronous contexts, allowing you to express parallelism without changing your existing patterns.
### Sync Map
```
@env.task
def sync_map_example(n: int) -> list[str]:
    # Synchronous version for easier migration
    results = []
    for result in flyte.map(process_item, range(n)):
        if isinstance(result, Exception):
            raise result
        results.append(result)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
### Async Map
```
@env.task
async def async_map_example(n: int) -> list[str]:
    # Async version using flyte.map - exact pattern from SDK examples
    results = []
    async for result in flyte.map.aio(process_item, range(n), return_exceptions=True):
        if isinstance(result, Exception):
            raise result
        results.append(result)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
The `flyte.map` function provides:
- **Dual interfaces**: `flyte.map.aio()` for async contexts, `flyte.map()` for sync contexts.
- **Built-in error handling**: `return_exceptions` parameter for graceful failure handling. This matches the `asyncio.gather` interface,
allowing you to decide how to handle errors.
If you are coming from Flyte 1, it allows you to replace `min_success_ratio` in a more flexible way.
- **Automatic UI grouping**: Creates logical groups for better workflow visualization.
- **Concurrency control**: Optional limits for resource management.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/flyte-2/migration ===
# Migration from Flyte 1 to Flyte 2
> [!NOTE]
> Automated migration from Flyte 1 to Flyte 2 is coming soon.
Flyte 2 will soon offer automated migration from Flyte 1 to 2.
In the meantime, you can migrate manually by following the steps below:
### 1. Move task configuration to a `TaskEnvironment` object
Instead of configuring the image, hardware resources, and so forth directly in the task decorator, you configure them in a `TaskEnvironment` object. For example:
```python
env = flyte.TaskEnvironment(name="my_task_env")
```
### 2. Replace workflow decorators
Then, you replace the `@workflow` and `@task` decorators with `@env.task` decorators.
### Flyte 1
Here's a simple hello world example with fan-out.
```python
import flytekit

@flytekit.task
def hello_world(name: str) -> str:
    return f"Hello, {name}!"

@flytekit.workflow
def main(names: list[str]) -> list[str]:
    return flytekit.map(hello_world)(names)
```
### Flyte 2 Sync
Change all the decorators to `@env.task` and swap out `flytekit.map` with `flyte.map`.
Notice that `flyte.map` is a drop-in replacement for Python's built-in `map` function.
```diff
-@flytekit.task
+@env.task
 def hello_world(name: str) -> str:
     return f"Hello, {name}!"

-@flytekit.workflow
+@env.task
 def main(names: list[str]) -> list[str]:
-    return flytekit.map(hello_world)(names)
+    return flyte.map(hello_world, names)
```
> **Note**
>
> Note that the reason our task decorator uses `env` is simply because that is the variable to which we assigned the `TaskEnvironment` above.
### Flyte 2 Async
To take advantage of full concurrency (not just parallelism), use Python async
syntax and the `asyncio` standard library to implement fan-out.
```diff
+import asyncio

 @env.task
-def hello_world(name: str) -> str:
+async def hello_world(name: str) -> str:
     return f"Hello, {name}!"

 @env.task
-def main(names: list[str]) -> list[str]:
+async def main(names: list[str]) -> list[str]:
-    return flyte.map(hello_world, names)
+    return await asyncio.gather(*[hello_world(name) for name in names])
```
> **Note**
>
> To use Python async syntax, you need to:
> - Use `asyncio.gather()` or `flyte.map()` for parallel execution
> - Add `async`/`await` keywords where you want parallelism
> - Keep existing sync task functions unchanged
>
> Learn more about the benefits of async in the **Flyte 2 > Asynchronous model** guide.
### 3. Leverage enhanced capabilities
- Add conditional logic and loops within workflows
- Implement proper error handling with try/except
- Create dynamic workflows that adapt to runtime conditions
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/getting-started ===
# Getting started
This section gives you a quick introduction to writing and running workflows on Union and Flyte 2.
## Prerequisites
You will need the following:
- An active Python virtual environment with Python 3.10 or later.
- The URL of your Union/Flyte instance.
- An existing project set up on your Union/Flyte instance where you have permission to run workflows.
## Install the `flyte` package
Install the latest `flyte` package in the virtual environment (we are currently in beta, so you will have to enable prerelease installation). For example:
```shell
pip install --pre flyte
```
Check that installation succeeded (and that you have activated your virtual environment):
```shell
flyte --version
```
## Create a config.yaml
Next, create a configuration file that points to your Flyte instance.
Use the **Flyte CLI > flyte > flyte create > flyte create config** command, making the following changes:
- Replace `my-org.my-company.com` with the actual URL of your Flyte backend instance.
You can simply copy the domain part of the URL from your browser when logged into your backend instance.
- Replace `my-project` with an actual project.
The project you specify must already exist on your Flyte backend instance.
```shell
flyte create config \
--endpoint my-org.my-company.com \
--builder local \
--domain development \
--project my-project
```
### Ensure local Docker is working
> [!NOTE]
> We are using the `--builder local` option here to specify that we want to build **Configure tasks > Container images** locally.
> If you were using a Union instance, you would typically use `--builder remote` instead to use Union's remote image builder.
> With Flyte OSS instances, `local` is the only option available.
To enable local image building, ensure that:
- You have Docker installed and running on your machine.
- You have permission to read from the public GitHub `ghcr.io` registry.
- You have successfully logged into the `ghcr.io` registry using Docker:
```shell
docker login ghcr.io
```
By default, the `flyte create config` command creates a `./.flyte/config.yaml` file in your current working directory.
See **Getting started > Local setup > Setting up a configuration file** for details.
> **Note**
>
> Run `flyte get config` to see the current configuration file being used by the `flyte` CLI.
## Hello world example
Create a file called `hello.py` with the following content:
```python
# hello.py
import flyte

# The `hello_env` TaskEnvironment is assigned to the variable `env`.
# It is then used in the `@env.task` decorator to define tasks.
# The environment groups configuration for all tasks defined within it.
env = flyte.TaskEnvironment(name="hello_env")

# We use the `@env.task` decorator to define a task called `fn`.
@env.task
def fn(x: int) -> int:  # Type annotations are required
    slope, intercept = 2, 5
    return slope * x + intercept

# We also use the `@env.task` decorator to define another task called `main`.
# This is the entrypoint task of the workflow.
# It calls the `fn` task defined above multiple times using `flyte.map`.
@env.task
def main(x_list: list[int] = list(range(10))) -> float:
    y_list = list(flyte.map(fn, x_list))  # flyte.map is like Python map, but runs in parallel.
    y_mean = sum(y_list) / len(y_list)
    return y_mean
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/hello.py)
## Understanding the code
In the code above we do the following:
- Import the `flyte` package.
- Define a `TaskEnvironment` to group the configuration used by tasks.
- Define two tasks using the `@env.task` decorator.
- Tasks are regular Python functions, but each runs in its own container.
- When deployed to your Union/Flyte instance, each task execution will run in its own separate container.
- Both tasks use the same `env` (the same `TaskEnvironment`) so, while each runs in its own container, those containers will be configured identically.
## Running the code
Assuming that your current directory looks like this:
```
.
├── hello.py
└── .flyte
    └── config.yaml
```
and your virtual environment is activated, you can run the script with:
```shell
flyte run hello.py main
```
This will package up the code and send it to your Flyte/Union instance for execution.
## Viewing the results
In your terminal, you should see output like this:
```shell
cg9s54pksbjsdxlz2gmc
https://my-instance.example.com/v2/runs/project/my-project/domain/development/cg9s54pksbjsdxlz2gmc
Run 'a0' completed successfully.
```
Click the link to go to your Flyte/Union instance and see the run in the UI.
## Subpages
- **Getting started > Local setup**
- **Getting started > Running tasks**
- **Getting started > Serving apps**
- **Getting started > Basic concepts**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/getting-started/local-setup ===
# Local setup
In this section we will explain the options for configuring the `flyte` CLI and SDK to connect to your Union/Flyte instance.
Before proceeding, make sure you have completed the steps in **Getting started**.
You will need to have the `uv` tool and the `flyte` Python package installed.
## Setting up a configuration file
In **Getting started** we used the `flyte create config` command to create a configuration file at `./.flyte/config.yaml`.
```shell
flyte create config \
--endpoint my-org.my-company.com \
--project my-project \
--domain development \
--builder local
```
This command creates a file called `./.flyte/config.yaml` in your current working directory:
```yaml
admin:
  endpoint: dns:///my-org.my-company.com
image:
  builder: local
task:
  domain: development
  org: my-org
  project: my-project
```
**See a full example using all available options:**
The example below creates a configuration file called `my-config.yaml` in the current working directory.
```shell
flyte create config \
--endpoint my-org.my-company.com \
--insecure \
--builder local \
--domain development \
--org my-org \
--project my-project \
--output my-config.yaml \
--force
```
See the **Flyte CLI > flyte > flyte create > flyte create config** section for details on the available parameters.
**Notes about the properties in the config file:**
**`admin` section**: contains the connection details for your Union/Flyte instance.
* `admin.endpoint` is the URL (always with `dns:///` prefix) of your Union/Flyte instance.
If your instance UI is found at https://my-org.my-company.com, the actual endpoint used in this file would be `dns:///my-org.my-company.com`.
* `admin.insecure` indicates whether to use an insecure connection (without TLS) to the Union/Flyte instance.
A setting of `true` is typically only used when connecting to a local instance on your own machine.
**`image` section**: contains the configuration for building Docker images for your tasks.
* `image.builder` specifies the image builder to use for building Docker images for your tasks.
* For Union instances this is usually set to `remote`, which means that the images will be built on Union's infrastructure using the Union `ImageBuilder`.
* For Flyte OSS instances, `ImageBuilder` is not available, so this property must be set to `local`.
This means that the images will be built locally on your machine.
You need to have Docker installed and running for this to work.
See **Configure tasks > Container images > Image building** for details.
**`task` section**: contains the configuration for running tasks on your Union/Flyte instance.
* `task.domain` specifies the domain in which the tasks will run.
Domains are used to separate different environments, such as `development`, `staging`, and `production`.
* `task.org` specifies the organization in which the tasks will run. The organization is usually synonymous with the name of the Union instance you are using, which is usually the same as the first part of the `admin.endpoint` URL.
* `task.project` specifies the project in which the tasks will run. The project you specify here will be the default project to which tasks are deployed if no other project is specified. The project you specify must already exist on your Union/Flyte instance (it will not be auto-created on first deploy).
## Using the configuration file
You can use the configuration file either explicitly by referencing it directly from a CLI or Python command, or implicitly by placing it in a specific location or setting an environment variable.
### Specify a configuration file explicitly
When using the `flyte` CLI, you can specify the configuration file explicitly by using the `--config` parameter, like this:
```shell
flyte --config my-config.yaml run hello.py main
```
or just using the `-c` shorthand:
```shell
flyte -c my-config.yaml run hello.py main
```
When invoking flyte commands programmatically, you have to first initialize the Flyte SDK with the configuration file.
To initialize with an explicitly specified configuration file, use **Getting started > Local setup > `flyte.init_from_config`**:
```python
flyte.init_from_config("my-config.yaml")
```
Then you can continue with other `flyte` commands, such as running the main task:
```python
run = flyte.run(main)
```
### Use the configuration file implicitly
You can also use the configuration file implicitly by placing it in a specific location or setting an environment variable.
You can use the `flyte` CLI without an explicit `--config`, like this:
```shell
flyte run hello.py main
```
You can also initialize the Flyte SDK programmatically without specifying a configuration file, like this:
```python
flyte.init_from_config()
```
In these cases, the SDK will search in the following order until it finds a configuration file:
* `./config.yaml` (i.e., in the current working directory).
* `./.flyte/config.yaml` (i.e., in the `.flyte` directory in the current working directory).
* `UCTL_CONFIG` (a file pointed to by this environment variable).
* `FLYTECTL_CONFIG` (a file pointed to by this environment variable).
* `~/.union/config.yaml`
* `~/.flyte/config.yaml`
### Checking your configuration
You can check your current configuration by running the following command:
```shell
flyte get config
```
This will return the current configuration as a serialized Python object. For example:
```shell
CLIConfig(
    Config(
        platform=PlatformConfig(endpoint='dns:///my-org.my-company.com', scopes=[]),
        task=TaskConfig(org='my-org', project='my-project', domain='development'),
        source=PosixPath('/Users/me/.flyte/config.yaml')
    ),
    log_level=None,
    insecure=None
)
```
## Inline configuration
### With `flyte` CLI
You can also use the Flyte SDK and CLI with inline configuration parameters, without using a configuration file.
When using the `flyte` CLI, some parameters are specified after the top-level command (i.e., `flyte`) while others are specified after the sub-command (for example, `run`).
For example, you can run a workflow using the following command:
```shell
flyte \
  --endpoint my-org.my-company.com \
  --org my-org \
  run \
  --domain development \
  --project my-project \
  hello.py \
  main
See the **Flyte CLI** for details.
### With `flyte` SDK
To initialize the Flyte SDK with inline parameters, you can use the **Flyte SDK > Packages > flyte > Methods > init()** function like this:
```python
flyte.init(
endpoint="dns:///my-org.my-company.com",
org="my-org",
project="my-project",
domain="development",
)
```
See **Flyte SDK > Packages > flyte > Methods > init()** for details.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/getting-started/running ===
# Running tasks
Flyte SDK lets you seamlessly switch between running your workflows locally on your machine and running them remotely on your Union/Flyte instance.
Furthermore, you can perform these actions either programmatically from within Python code or from the command line using the `flyte` CLI.
## Running remotely
### From the command-line
To run your code on your Union/Flyte instance, you can use the `flyte run` command without the `--local` flag:
```shell
flyte run hello.py main
```
This deploys your code to the configured Union/Flyte instance and runs it immediately. (Since no explicit `--config` is specified, the configuration is located according to the search order described in **Getting started > Local setup > Using the configuration file > Use the configuration file implicitly**.)
### From Python
To run your workflow remotely from Python, use **Flyte SDK > Packages > flyte > Methods > run()** by itself, like this:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte==2.0.0b31",
# ]
# main = "main"
# params = "name='World'"
# ///

# run_from_python.py
import flyte

env = flyte.TaskEnvironment(name="hello_world")

@env.task
def main(name: str) -> str:
    return f"Hello, {name}!"

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/running/run_from_python.py)
This is the approach we use throughout the examples in this guide.
We execute the script, which invokes the `flyte.run()` function with the top-level task as a parameter.
The `flyte.run()` function then deploys and runs the code in that file on your remote Union/Flyte instance.
## Running locally
### From the command-line
To run your code on your local machine, you can use the `flyte run` command with the `--local` flag:
```shell
flyte run --local hello.py main
```
### From Python
To run your workflow locally from Python, you chain **Getting started > Running tasks > `flyte.with_runcontext()`** with **Flyte SDK > Packages > flyte > Methods > run()** and specify the run `mode="local"`, like this:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte==2.0.0b31",
# ]
# main = "main"
# params = "name='World'"
# ///

# run_local_from_python.py
import flyte

env = flyte.TaskEnvironment(name="hello_world")

@env.task
def main(name: str) -> str:
    return f"Hello, {name}!"

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext(mode="local").run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/running/run_local_from_python.py)
Running your workflow locally is useful for testing and debugging, as it allows you to run your code without deploying it to a remote instance.
It also lets you quickly iterate on your code without the overhead of deployment.
Obviously, if your code relies on remote resources or services, you will need to mock those in your local environment, or temporarily work around any missing functionality.
At the very least, local execution can be used to catch immediate syntax errors and other relatively simple issues before deploying your code to a remote instance.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/getting-started/serving-apps ===
# Serving apps
Flyte SDK lets you serve apps on your Union/Flyte instance, making them accessible via HTTP endpoints. Apps are long-running services that can be accessed by users or other services.
> [!TIP] Prerequisites
> Make sure to run the **Getting started > Local setup** before going through this guide.
First install FastAPI in your virtual environment:
```shell
pip install fastapi
```
## Hello world example
Create a file called `hello_app.py` with the following content:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte==2.0.0b31",
#     "fastapi",
#     "uvicorn",
# ]
# ///
"""A simple "Hello World" FastAPI app example for serving."""

import pathlib

from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# Define a simple FastAPI application
app = FastAPI(
    title="Hello World API",
    description="A simple FastAPI application",
    version="1.0.0",
)

# Create an AppEnvironment for the FastAPI app
env = FastAPIAppEnvironment(
    name="hello-app",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)

# Define API endpoints
@app.get("/")
async def root():
    return {"message": "Hello, World!"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

# Serving this script will deploy and serve the app on your Union/Flyte instance.
if __name__ == "__main__":
    # Initialize Flyte from a config file.
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    # Serve the app remotely.
    app_instance = flyte.serve(env)
    # Print the app URL.
    print(app_instance.url)
    print("App 'hello-app' is now serving.")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/serving/hello_app.py)
## Understanding the code
In the code above we do the following:
- Import the `flyte` package and `FastAPIAppEnvironment` from `flyte.app.extras`.
- Define a FastAPI application using the `FastAPI` class.
- Create an `AppEnvironment` using `FastAPIAppEnvironment`:
- Apps are long-running services, unlike tasks which run to completion.
- The `FastAPIAppEnvironment` automatically configures the app to run with uvicorn.
- We specify the container image with required dependencies (FastAPI and uvicorn).
- We set resource limits (CPU and memory).
- We disable authentication for this example (`requires_auth=False`) so you can easily access the app with a `curl` command.
## Serving the app
Make sure that your `config.yaml` file is in the same directory as your `hello_app.py` script.
Now, serve the app with:
```shell
flyte serve hello_app.py env
```
You can also serve it via `python`:
```shell
python hello_app.py
```
This uses the main guard (`if __name__ == "__main__":`) section of the script, which invokes `flyte.init_from_config()` to set up the connection with your Union/Flyte instance and `flyte.serve()` to deploy and serve your app on that instance.
> [!NOTE]
> The example scripts in this guide have a main guard that programmatically serves the apps defined in the same file.
> All you have to do is execute the script itself.
> You can also serve apps using the `flyte serve` CLI command. We will cover this in a later section.
## Viewing the results
In your terminal, you should see output like this:
```shell
https://my-instance.flyte.com/v2/domain/development/project/flytesnacks/apps/hello-app
App 'hello-app' is now serving.
```
Click the link to go to your Union instance and see the app in the UI, where you can find
the app URL, or visit `/docs` for the interactive Swagger UI API documentation.
## Next steps
Now that you've served your first app, you can learn more about:
- **Configure apps**: Learn how to configure app environments, including images, resources, ports, and more
- **Build apps**: Explore different types of apps you can build (FastAPI, Streamlit, vLLM, SGLang)
- **Serve and deploy apps**: Understand the difference between serving (development) and deploying (production) apps
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/getting-started/basic-concepts ===
# Basic concepts
To understand how Flyte 2 works, it helps to establish a few definitions and concepts.
* **Workflow**: A collection of tasks linked by invocation, with a top-most task that is the entry point of the workflow.
We sometimes refer to this as the "parent", "driver", or "top-most" task.
Unlike in Flyte 1, there is no explicit `@workflow` decorator; instead, the workflow is defined implicitly by the structure of the code.
Nonetheless, you will often see the assemblage of tasks referred to as a "workflow".
* `TaskEnvironment`: A `[[TaskEnvironment]]` object is the abstraction that defines the hardware and software environment in which one or more tasks are executed.
* The hardware environment is specified by parameters that define the type of compute resources (e.g., CPU, memory) allocated to the task.
* The software environment is specified by parameters that define the container image, including dependencies, required to run the task.
* **Task**: A Python function.
* Tasks are defined using the `[[TaskEnvironment.task]]` decorator.
* Tasks can involve invoking helper functions as well as other tasks and assembling outputs from those invocations.
* **Run**: A `[[Run]]` is the execution of a task directly initiated by a user and all its descendant tasks, considered together.
* **Action**: An `[[Action]]` is the execution of a single task, considered independently. A run consists of one or more actions.
* **AppEnvironment**: An `[[AppEnvironment]]` object is the abstraction that defines the hardware and software environment in which an app runs.
* The hardware environment is specified by parameters that define the type of compute resources (e.g., CPU, memory, GPU) allocated to the app.
* The software environment is specified by parameters that define the container image, including dependencies, required to run the app.
* Apps have additional configuration options specific to services, such as port configuration, scaling behavior, and domain settings.
* **App**: A long-running service that provides functionality via HTTP endpoints. Unlike tasks, which run to completion, apps remain active and can handle multiple requests over time.
* **App vs Task**: The fundamental difference is that apps are services that stay running and handle requests, while tasks are functions that execute once and complete.
- Apps are suited for short-running API calls that need low latency and where durability is not required.
- Apps may expose one or more endpoints, while tasks consist of a single function entrypoint.
- Every invocation of a task is durable and can run for long periods of time.
- In Flyte, durability means that inputs and outputs are recorded in an object store, are visible in the UI, and can be cached. In multi-step tasks, durability provides the ability to resume execution from where it left off without re-computing the outputs of completed tasks.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration ===
# Configure tasks
As we saw in **Getting started**, you can run any Python function as a task in Flyte just by decorating it with `@env.task`.
This allows you to run your Python code in a distributed manner, with each function running in its own container.
Flyte manages the spinning up of the containers, the execution of the code, and the passing of data between the tasks.
The simplest possible case is a `TaskEnvironment` with only a `name` parameter, and an `env.task` decorator, with no parameters:
```
env = flyte.TaskEnvironment(name="my_env")

@env.task
async def my_task(name: str) -> str:
    return f"Hello {name}!"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py)
> [!NOTE]
> Notice how the `TaskEnvironment` is assigned to the variable `env` and then that variable is
> used in the `@env.task`. This is what connects the `TaskEnvironment` to the task definition.
>
> In the following we will often use `@env.task` generically to refer to the decorator,
> but it is important to remember that it is actually a decorator attached to a specific
> `TaskEnvironment` object, and the `env` part can be any variable name you like.
This will run your task in the default container environment with default settings.
But, of course, one of the key advantages of Flyte is the ability to control the software environment, hardware environment, and other execution parameters for each task, right in your Python code.
In this section we will explore the various configuration options available for tasks in Flyte.
## Task configuration levels
Task configuration is done at three levels. From most general to most specific, they are:
* The `TaskEnvironment` level: setting parameters when defining the `TaskEnvironment` object.
* The `@env.task` decorator level: Setting parameters in the `@env.task` decorator when defining a task function.
* The task invocation level: Using the `task.override()` method when invoking task execution.
Each level has its own set of parameters, and some parameters are shared across levels.
For shared parameters, the more specific level will override the more general one.
### Example
Here is an example of how these levels work together, showing each level with all available parameters:
```
# Level 1: TaskEnvironment - Base configuration
env_2 = flyte.TaskEnvironment(
    name="data_processing_env",
    image=flyte.Image.from_debian_base(),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    env_vars={"MY_VAR": "value"},
    # secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY"),
    cache="disable",
    # pod_template=my_pod_template,
    # reusable=flyte.ReusePolicy(replicas=2, idle_ttl=300),
    depends_on=[another_env],
    description="Data processing task environment",
    # plugin_config=my_plugin_config
)

# Level 2: Decorator - Override some environment settings
@env_2.task(
    short_name="process",
    # secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY_2"),
    cache="auto",
    # pod_template=my_pod_template,
    report=True,
    max_inline_io_bytes=100 * 1024,
    retries=3,
    timeout=60,
    docs="This task processes data and generates a report."
)
async def process_data(data_path: str) -> str:
    return f"Processed {data_path}"

# Level 3: Invocation - Override settings per call with task.override()
@env_2.task
async def invoke_process_data() -> str:
    result = await process_data.override(
        resources=flyte.Resources(cpu=4, memory="2Gi"),
        env_vars={"MY_VAR": "new_value"},
        # secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY_3"),
        cache="auto",
        max_inline_io_bytes=100 * 1024,
        retries=3,
        timeout=60
    )("input.csv")
    return result
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py)
### Parameter interaction
Here is an overview of all task configuration parameters available at each level and how they interact:
| Parameter | `TaskEnvironment` | `@env.task` decorator | `override` on task invocation |
|-------------------------|--------------------|----------------------------|-------------------------------|
| **name** | β Yes (required) | β No | β No |
| **short_name** | β No | β Yes | β Yes |
| **image** | β Yes | β No | β No |
| **resources** | β Yes | β No | β Yes (if not `reusable`) |
| **env_vars** | β Yes | β No | β Yes (if not `reusable`) |
| **secrets** | β Yes | β No | β Yes (if not `reusable`) |
| **cache** | β Yes | β Yes | β Yes |
| **pod_template** | β Yes | β Yes | β Yes |
| **reusable** | β Yes | β No | β Yes |
| **depends_on** | β Yes | β No | β No |
| **description** | β Yes | β No | β No |
| **plugin_config** | β Yes | β No | β No |
| **report** | β No | β Yes | β No |
| **max_inline_io_bytes** | β No | β Yes | β Yes |
| **retries** | β No | β Yes | β Yes |
| **timeout** | β No | β Yes | β Yes |
| **triggers** | β No | β Yes | β No |
| **interruptible** | β Yes | β Yes | β Yes |
| **queue** | β Yes | β Yes | β Yes |
| **docs** | β No | β Yes | β No |
## Task configuration parameters
The full set of parameters available for configuring a task environment, task definition, and task invocation are:
### `name`
* Type: `str` (required)
* Defines the name of the `TaskEnvironment`.
Since it specifies the name *of the environment*, it cannot, logically, be overridden at the `@env.task` decorator or the `task.override()` invocation level.
It is used in conjunction with the name of each `@env.task` function to define the fully-qualified name of the task.
The fully qualified name is always the `TaskEnvironment` name (the one above) followed by a period and then the task function name (the name of the Python function being decorated).
For example:
```
env = flyte.TaskEnvironment(name="my_env")

@env.task
async def my_task(name: str) -> str:
    return f"Hello {name}!"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py)
Here, the name of the TaskEnvironment is `my_env` and the fully qualified name of the task is `my_env.my_task`.
The `TaskEnvironment` name and the fully qualified task name are both fixed and cannot be overridden.
### `short_name`
* Type: `str`
* Defines the short name of the task or action (the execution of a task).
Since it specifies the name *of the task*, it is not, logically, available to be set at the `TaskEnvironment` level.
By default, the short name of a task is the name of the task function (the name of the Python function being decorated).
The short name is used, for example, in parts of the UI.
Overriding it does not change the fully qualified name of the task.
### `image`
* Type: `Union[str, Image, Literal['auto']]`
* Specifies the Docker image to use for the task container.
Can be a URL reference to a Docker image, an **Configure tasks > `Image` object**, or the string `auto`.
If set to `auto`, or if this parameter is not set, the default image will be used.
* Only settable at the `TaskEnvironment` level.
* See **Configure tasks > Container images**.
### `resources`
* Type: `Optional[Resources]`
* Specifies the compute resources, such as CPU and Memory, required by the task environment using a
**Configure tasks > `Resources`** object.
* Can be set at the `TaskEnvironment` level and overridden at the `task.override()` invocation level
(but only if `reusable` is not in effect).
* See **Configure tasks > Resources**.
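For example, an environment might request modest resources while one invocation overrides them for a heavier call. A minimal sketch (the environment name and resource figures are illustrative):
```python
import flyte

env = flyte.TaskEnvironment(
    name="resource_env",
    resources=flyte.Resources(cpu=2, memory="1Gi"),
)

@env.task
async def transform(data: str) -> str:
    return data.upper()

@env.task
async def main() -> str:
    # Override the environment's resources for this one invocation.
    return await transform.override(
        resources=flyte.Resources(cpu=8, memory="4Gi"),
    )("payload")
```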
### `env_vars`
* Type: `Optional[Dict[str, str]]`
* A dictionary of environment variables to be made available in the task container.
These variables can be used to configure the task at runtime, such as setting API keys or other configuration values.
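For example, a variable set on the environment can be read inside the task with `os.environ`. A minimal sketch (the variable name is illustrative):
```python
import os

import flyte

env = flyte.TaskEnvironment(
    name="env_var_env",
    env_vars={"LOG_LEVEL": "debug"},
)

@env.task
async def show_log_level() -> str:
    # The variable is available in the task container's environment.
    return os.environ["LOG_LEVEL"]
```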
### `secrets`
* Type: `Optional[SecretRequest]` where `SecretRequest` is an alias for `Union[str, Secret, List[str | Secret]]`
* The secrets to be made available in the task container.
* Can be set at the `TaskEnvironment` level and overridden at the `task.override()` invocation level, but only if `reusable` is not in effect.
* See **Configure tasks > Secrets** and the API docs for the **Configure tasks > `Secret` object**.
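Echoing the commented-out lines in the overview example above, a secret can be surfaced as an environment variable. A minimal sketch (it assumes a secret named `openapi_key` already exists on your instance):
```python
import os

import flyte

env = flyte.TaskEnvironment(
    name="secret_env",
    secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY"),
)

@env.task
async def call_api() -> str:
    # The secret value is injected into the container as MY_API_KEY.
    return f"Key length: {len(os.environ['MY_API_KEY'])}"
```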
### `cache`
* Type: `CacheRequest`, where `CacheRequest` is an alias for `Literal["auto", "override", "disable", "enabled"] | Cache`.
* Specifies the caching policy to be used for this task.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level
and at the `task.override()` invocation level.
* See **Configure tasks > Caching**.
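As a sketch of how the levels interact, an environment can disable caching by default while a single task opts back in (both forms shown appear in the overview example above):
```python
import flyte

env = flyte.TaskEnvironment(name="cache_env", cache="disable")

@env.task(cache=flyte.Cache(behavior="auto"))
async def expensive(x: int) -> int:
    # Cached: re-running with the same input reuses the stored result.
    return x * x

@env.task
async def cheap(x: int) -> int:
    # Inherits the environment-level "disable" policy.
    return x + 1
```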
### `pod_template`
* Type: `Optional[Union[str, kubernetes.client.V1PodTemplate]]`
* A pod template that defines the Kubernetes pod configuration for the task: either a string reference to a named template or a `kubernetes.client.V1PodTemplate` object.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level and the `task.override()` invocation level.
* See **Configure tasks > Pod templates**.
### `reusable`
> [!NOTE]
> The `reusable` setting controls the **Configure tasks > Reusable containers**.
> This feature is only available when running your Flyte code on a Union.ai backend.
> See [one of the Union.ai product variants of this page](/docs/v2/byoc//user-guide/reusable-containers) for details.
### `depends_on`
* Type: `List[Environment]`
* A list of **Configure tasks > `Environment`** objects that this `TaskEnvironment` depends on.
When deploying this `TaskEnvironment`, the system will ensure that any dependencies of the listed `Environment`s are also available.
This is useful when you have a set of task environments that depend on each other.
* Can only be set at the `TaskEnvironment` level, not at the `@env.task` decorator level or the `task.override()` invocation level.
* See **Configure tasks > Multiple environments**
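For instance, deploying a driver environment can ensure that a helper environment it calls into is deployed as well. A minimal sketch (the environment names are illustrative):
```python
import flyte

helper_env = flyte.TaskEnvironment(name="helper_env")

driver_env = flyte.TaskEnvironment(
    name="driver_env",
    depends_on=[helper_env],  # Deploying driver_env also makes helper_env available.
)
```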
### `description`
* Type: `Optional[str]`
* A description of the task environment.
This can be used to provide additional context about the task environment, such as its purpose or usage.
* Can only be set at the `TaskEnvironment` level, not at the `@env.task` decorator level
or the `task.override()` invocation level.
### `plugin_config`
* Type: `Optional[Any]`
* Additional configuration for plugins that can be used with the task environment.
This can include settings for specific plugins that are used in the task environment.
* Can only be set at the `TaskEnvironment` level, not at the `@env.task` decorator level
or the `task.override()` invocation level.
### `report`
* Type: `bool`
* Whether to generate the HTML report for the task.
If set to `True`, the task will generate an HTML report that can be viewed in the Flyte UI.
* Can only be set at the `@env.task` decorator level,
not at the `TaskEnvironment` level or the `task.override()` invocation level.
* See **Build tasks > Reports**.
### `max_inline_io_bytes`
* Type: `int`
* Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task
(e.g., primitives, strings, dictionaries).
Does not apply to **Build tasks > Files and directories**, or **Build tasks > Data classes and structures** (since these are passed by reference).
* Can be set at the `@env.task` decorator level and overridden at the `task.override()` invocation level.
If not set, the default value is `MAX_INLINE_IO_BYTES` (which is 100 MiB).
### `retries`
* Type: `Union[int, RetryStrategy]`
* The number of retries for the task, or a `RetryStrategy` object that defines the retry behavior.
If set to `0`, no retries will be attempted.
* Can be set at the `@env.task` decorator level and overridden at the `task.override()` invocation level.
* See **Configure tasks > Retries and timeouts**.
### `timeout`
* Type: `Union[timedelta, int]`
* The timeout for the task, either as a `timedelta` object or an integer representing seconds.
If set to `0`, no timeout will be applied.
* Can be set at the `@env.task` decorator level and overridden at the `task.override()` invocation level.
* See **Configure tasks > Retries and timeouts**.
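The two parameters often appear together; for example, a task calling an unreliable external service might allow three retries with a 60-second timeout. A minimal sketch (the names are illustrative):
```python
import flyte

env = flyte.TaskEnvironment(name="retry_env")

@env.task(retries=3, timeout=60)
async def fetch(url: str) -> str:
    # Retried up to 3 times, with a 60-second timeout applied to the task.
    return f"Fetched {url}"
```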
### `triggers`
* Type: `Tuple[Trigger, ...] | Trigger`
* A trigger or tuple of triggers that define when the task should be executed.
* Can only be set at the `@env.task` decorator level. It cannot be overridden.
* See **Configure tasks > Triggers**.
### `interruptible`
* Type: `bool`
* Specifies whether the task is interruptible.
If set to `True`, the task can be scheduled on a spot instance; otherwise it can only be scheduled on on-demand instances.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level and at the `task.override()` invocation level.
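For example, a sketch that defaults an environment to spot instances while keeping one task on on-demand capacity:
```python
import flyte

# Tasks in this environment may be scheduled on spot instances by default.
env = flyte.TaskEnvironment(name="spot_env", interruptible=True)

@env.task(interruptible=False)  # This task requires an on-demand instance.
def checkpoint_sensitive_step() -> str:
    return "done"
```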
### `queue`
* Type: `Optional[str]`
* Specifies the queue to which the task should be directed, where the queue is identified by its name.
If set to `None`, the default queue will be used.
Queues point to specific partitions of your compute infrastructure (for example, a specific cluster in a multi-cluster setup).
They are configured as part of your Union/Flyte deployment.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level
and at the `task.override()` invocation level.
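For example, a sketch directing an environment's tasks to a named queue (the queue name is illustrative and must match one configured in your deployment):
```python
import flyte

env = flyte.TaskEnvironment(name="gpu_env", queue="gpu-cluster")

@env.task
def train() -> str:
    return "trained via the gpu-cluster queue"
```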
### `docs`
* Type: `Optional[Documentation]`
* Documentation for the task, including usage examples and explanations of the task's behavior.
* Can only be set at the `@env.task` decorator level. It cannot be overridden.
## Subpages
- **Configure tasks > Container images**
- **Configure tasks > Resources**
- **Configure tasks > Secrets**
- **Configure tasks > Caching**
- **Configure tasks > Reusable containers**
- **Configure tasks > Pod templates**
- **Configure tasks > Multiple environments**
- **Configure tasks > Retries and timeouts**
- **Configure tasks > Triggers**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/container-images ===
# Container images
The `image` parameter of the **Configure tasks > Container images > `TaskEnvironment`** is used to specify a container image.
Every task defined using that `TaskEnvironment` will run in a container based on that image.
If a `TaskEnvironment` does not specify an `image`, it will use the default Flyte image ([`ghcr.io/flyteorg/flyte:py{python-version}-v{flyte_version}`](https://github.com/orgs/flyteorg/packages/container/package/flyte)).
## Specifying your own image directly
You can directly reference an image by URL in the `image` parameter, like this:
```python
env = flyte.TaskEnvironment(
name="my_task_env",
image="docker.io/myorg/myimage:mytag"
)
```
This works well if you have a pre-built image available in a public registry like Docker Hub or in a private registry that your Union/Flyte instance can access.
## Specifying your own image with the `flyte.Image` object
You can also construct an image programmatically using the `flyte.Image` object.
The `flyte.Image` object provides a fluent interface for building container images with specific dependencies.
You start building your image with one of the `from_` methods:
* `[[Image.from_base()]]`: Start from a pre-built image (Note: The image should be accessible to the imagebuilder).
* `[[Image.from_debian_base()]]`: Start from a [Debian](https://www.debian.org/)-based image that already contains Flyte.
* `[[Image.from_uv_script()]]`: Start with a new image built from a [uv script](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies) (slower, but easier).
You can then layer on additional components using the `with_` methods:
* `[[Image.with_apt_packages()]]`: Add Debian packages to the image (via `apt-get install ...`).
* `[[Image.with_commands()]]`: Add commands to run in the image (e.g. `chmod a+x ...`, `curl ...`, `wget ...`).
* `[[Image.with_dockerignore()]]`: Specify a `.dockerignore` file that will be respected during the image build.
* `[[Image.with_env_vars()]]`: Set environment variables in the image.
* `[[Image.with_pip_packages()]]`: Add Python packages to the image (installed via `uv pip install ...`).
* `[[Image.with_requirements()]]`: Specify a `requirements.txt` file (all listed packages will be installed).
* `[[Image.with_source_file()]]`: Specify a source file to include in the image (the file will be copied).
* `[[Image.with_source_folder()]]`: Specify a source folder to include in the image (the entire folder will be copied).
* `[[Image.with_uv_project()]]`: Use this with `pyproject.toml`- or `uv.lock`-based projects.
* `[[Image.with_poetry_project()]]`: Create a new image from the specified `poetry.lock`.
* `[[Image.with_workdir()]]`: Specify the working directory for the image.
You can also specify an image in one shot (with no possibility of layering) with:
* `[[Image.from_dockerfile()]]`: Build the final image from a single Dockerfile (useful if you already have a Dockerfile).
Additionally, the `Image` class provides:
* `[[Image.clone()]]`: Clone an existing image. (Note: every `with_*` operation clones, since every image is immutable; `clone()` is useful when you need a new named image.)
* `[[Image.validate()]]`: Validate the image configuration.
* `[[Image.with_local_v2()]]`: Does not add a layer; instead, it overrides any existing builder configuration and builds the image locally. See **Configure tasks > Container images > Image building** for more details.
Here are some examples of the most common patterns for building images with `flyte.Image`.
## Example: Defining a custom image with `Image.from_debian_base`
The `[[Image.from_debian_base()]]` method provides the default Flyte image as the base.
This image is itself based on the official Python Docker image (specifically `python:{version}-slim-bookworm`) with the addition of the Flyte SDK pre-installed.
Starting there, you can layer additional features onto your image.
For example:
```python
import flyte
import numpy as np
# Define the task environment
env = flyte.TaskEnvironment(
name="my_env",
image = (
flyte.Image.from_debian_base(
name="my-image",
python_version=(3, 13)
# registry="registry.example.com/my-org" # Only needed for local builds
)
.with_apt_packages("libopenblas-dev")
.with_pip_packages("numpy")
.with_env_vars({"OMP_NUM_THREADS": "4"})
)
)
@env.task
def main(x_list: list[int]) -> float:
arr = np.array(x_list)
return float(np.mean(arr))
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main, x_list=list(range(10)))
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/container-images/from_debian_base.py)
> [!NOTE]
> The `registry` parameter is only needed if you are building the image locally. It is not required when using the Union backend `ImageBuilder`.
> See **Configure tasks > Container images > Image building** for more details.
> [!NOTE]
> Images built with `[[Image.from_debian_base()]]` do not include CA certificates by default, which can cause TLS
> validation errors and block access to HTTPS-based storage such as Amazon S3. Libraries like Polars (e.g., `polars.scan_parquet()`) are particularly affected.
> **Solution:** Add `"ca-certificates"` using `.with_apt_packages()` in your image definition.
## Example: Defining an image based on uv script metadata
Another common technique for defining an image is to use [`uv` inline script metadata](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies) to specify your dependencies right in your Python file and then use the `flyte.Image.from_uv_script()` method to create a `flyte.Image` object.
The `from_uv_script` method starts with the default Flyte image and adds the dependencies specified in the `uv` metadata.
For example:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "numpy"
# ]
# main = "main"
# params = "x_list=[1,2,3,4,5,6,7,8,9,10]"
# ///
import flyte
import numpy as np
env = flyte.TaskEnvironment(
name="my_env",
image=flyte.Image.from_uv_script(
__file__,
name="my-image"
# registry="registry.example.com/my-org" # Only needed for local builds
)
)
@env.task
def main(x_list: list[int]) -> float:
arr = np.array(x_list)
return float(np.mean(arr))
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main, x_list=list(range(10)))
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/container-images/from_uv_script.py)
The advantage of this approach is that the dependencies used when running a script locally and when running it on the Flyte/Union backend are always the same (as long as you use `uv` to run your scripts locally).
This means you can develop and test your scripts in a consistent environment, reducing the chances of encountering issues when deploying to the backend.
The example above uses `flyte.init_from_config()` to connect to your configured backend for remote runs.
To run the script locally instead, replace `flyte.init_from_config()` with `flyte.init()`.
> [!NOTE]
> When using `uv` metadata in this way, be sure to include the `flyte` package in your `uv` script dependencies.
> This will ensure that `flyte` is installed when running the script locally using `uv run`.
> When running on the Flyte/Union backend, the `flyte` package from the uv script dependencies will overwrite the one included automatically from the default Flyte image.
## Image building
There are two ways that the image can be built:
* If you are running a Flyte OSS instance then the image will be built locally on your machine and pushed to the container registry you specified in the `Image` definition.
* If you are running a Union instance, the image can be built locally, as with Flyte OSS, or using the Union `ImageBuilder`, which runs remotely on Union's infrastructure.
### Configuring the `builder`
In **Getting started > Local setup**, we discussed the `image.builder` property in the `config.yaml`.
For Flyte OSS instances, this property must be set to `local`.
For Union instances, this property can be set to `remote` to use the Union `ImageBuilder`, or `local` to build the image locally on your machine.
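As a sketch, the relevant portion of a `config.yaml` might look like this (other settings omitted; see **Getting started > Local setup** for the full file):
```yaml
image:
  builder: local   # Use "remote" on a Union instance to build with ImageBuilder
```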
### Local image building
When `image.builder` in the `config.yaml` is set to `local`, `flyte.run()` does the following:
* Builds the Docker image using your local Docker installation, installing the dependencies specified in your `Image` definition.
* Pushes the image to the container registry you specified.
* Deploys your code to the backend.
* Kicks off the execution of your workflow.
* Before the task that uses your custom image is executed, the backend pulls the image from the registry to set up the container.
> [!NOTE]
> In the examples above, the registry placeholder was `registry="registry.example.com/my-org"`.
>
> Be sure to replace it with the URL of your actual container registry.
You must ensure that:
* Docker is running on your local machine.
* You have successfully run `docker login` for that registry from your local machine (for example, GitHub uses the syntax `echo $GITHUB_TOKEN | docker login ghcr.io -u USERNAME --password-stdin`).
* Your Union/Flyte installation has read access to that registry.
> [!NOTE]
> If you are using the GitHub container registry (`ghcr.io`)
> note that images pushed there are private by default.
> You may need to go to the image URI, click **Package Settings**, and change the visibility to public in order to access the image.
>
> Other registries (such as Docker Hub) require that you pre-create the image repository before pushing the image.
> In that case you can set it to public when you create it.
>
> Public images are on the public internet and should only be used for testing purposes.
> Do not place proprietary code in public images.
### Remote `ImageBuilder`
`ImageBuilder` is a service provided by Union that builds container images on Union's infrastructure and provides an internal container registry for storing the built images.
When `image.builder` in the `config.yaml` is set to `remote` (and you are running Union.ai), `flyte.run()` does the following:
* Builds the Docker image on your Union instance with `ImageBuilder`.
* Pushes the image to a registry:
* If you did not specify a `registry` in the `Image` definition, it pushes to the internal registry in your Union instance.
* If you did specify a `registry`, it pushes to that registry. Be sure to also set the `registry_secret` parameter in the `Image` definition to enable `ImageBuilder` to authenticate to that registry (see **Configure tasks > Container images > Image building > Remote `ImageBuilder` > ImageBuilder with external registries**).
* Deploys your code to the backend.
* Kicks off the execution of your workflow.
* Before the task that uses your custom image is executed, the backend pulls the image from the registry to set up the container.
No Docker setup or other local configuration is required on your part.
#### ImageBuilder with external registries
If you want to push the images built by `ImageBuilder` to an external registry, you can do so by setting the `registry` parameter in the `Image` object.
You will also need to set the `registry_secret` parameter to provide the secret needed to push and pull images to the private registry.
For example:
```python
# Add registry credentials so the Union remote builder can pull the base image
# and push the resulting image to your private registry.
image=flyte.Image.from_debian_base(
name="my-image",
base_image="registry.example.com/my-org/my-private-image:latest",
registry="registry.example.com/my-org",
registry_secret="my-secret"
)
# Reference the same secret in the TaskEnvironment so Flyte can pull the image at runtime.
env = flyte.TaskEnvironment(
name="my_task_env",
image=image,
secrets="my-secret"
)
```
The value of the `registry_secret` parameter must be the name of a Flyte secret of type `image_pull` that contains the credentials needed to access the private registry. It must match the name specified in the `secrets` parameter of the `TaskEnvironment` so that Flyte can use it to pull the image at runtime.
To create an `image_pull` secret for the remote builder and the task environment, run the following command:
```shell
$ flyte create secret --type image_pull my-secret --from-file ~/.docker/config.json
```
The format of this secret matches the standard Kubernetes [image pull secret](https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/#log-in-to-docker-hub), and should look like this:
```json
{
"auths": {
"registry.example.com": {
"auth": "base64-encoded-auth"
}
}
}
```
> [!NOTE]
> The `auth` field contains the base64-encoded credentials for your registry (username and password or token).
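As an illustration, the `auth` value is typically produced by base64-encoding `username:password` (or `username:token`):
```shell
echo -n 'myusername:mytoken' | base64
```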
### Install private PyPI packages
To install Python packages from a private PyPI index (for example, from GitHub), you can mount a secret to the image layer.
This allows your build to authenticate securely during dependency installation.
For example:
```python
from flyte import Image, Secret

private_package = "git+https://$GITHUB_PAT@github.com/pingsutw/flytex.git@2e20a2acebfc3877d84af643fdd768edea41d533"
image = (
Image.from_debian_base()
.with_apt_packages("git")
.with_pip_packages(private_package, pre=True, secret_mounts=Secret("GITHUB_PAT"))
)
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/resources ===
# Resources
Task resources specify the computational limits and requests (CPU, memory, GPU, storage) that will be allocated to each task's container during execution.
To specify resource requirements for your task, instantiate a `Resources` object with the desired parameters and assign it to the `resources` parameter of the `TaskEnvironment`, of the `@env.task` decorator, or of the `override` function (for invocation overrides).
Every task defined using that `TaskEnvironment` will run with the specified resources.
If a specific task has its own `resources` defined in the decorator, it will override the environment's resources for that task only.
If neither `TaskEnvironment` nor the task decorator specifies `resources`, the default resource allocation will be used.
## Resources data class
The `Resources` data class provides the following initialization parameters:
```python
resources = flyte.Resources(
cpu: Union[int, float, str, Tuple[Union[int, float, str], Union[int, float, str]], None] = None,
memory: Union[str, Tuple[str, str], None] = None,
gpu: Union[str, int, flyte.Device, None] = None,
disk: Union[str, None] = None,
shm: Union[str, Literal["auto"], None] = None
)
```
Each parameter is optional and allows you to specify different types of resources:
- **`cpu`**: CPU allocation - can be a number, string, or tuple for request/limit ranges (e.g., `2` or `(2, 4)`).
- **`memory`**: Memory allocation - string with units (e.g., `"4Gi"`) or tuple for ranges.
- **`gpu`**: GPU allocation - accelerator string (e.g., `"A100:2"`), count, or `Device` (a **Configure tasks > Resources > GPU resources**, **Configure tasks > Resources > TPU resources** or **Configure tasks > Resources > Custom device specifications**).
- **`disk`**: Ephemeral storage - string with units (e.g., `"10Gi"`).
- **`shm`**: Shared memory - string with units or `"auto"` for automatic sizing (e.g., `"8Gi"` or `"auto"`).
## Examples
### Usage in TaskEnvironment
Here's a complete example of defining a `TaskEnvironment` with resource specifications for a machine learning training workload:
```
import flyte
# Define a TaskEnvironment for ML training tasks
env = flyte.TaskEnvironment(
name="ml-training",
resources=flyte.Resources(
cpu=("2", "4"), # Request 2 cores, allow up to 4 cores for scaling
memory=("2Gi", "12Gi"), # Request 2 GiB, allow up to 12 GiB for large datasets
disk="50Gi", # 50 GiB ephemeral storage for checkpoints
shm="8Gi" # 8 GiB shared memory for efficient data loading
)
)
# Use the environment for tasks
@env.task
async def train_model(dataset_path: str) -> str:
# This task will run with flexible resource allocation
return "model trained"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/resources/resources.py)
### Usage in a task-specific override
```
# Demonstrate resource override at task invocation level
@env.task
async def heavy_training_task() -> str:
return "heavy model trained with overridden resources"
@env.task
async def main():
# Task using environment-level resources
result = await train_model("data.csv")
print(result)
# Task with overridden resources at invocation time
result = await heavy_training_task.override(
resources=flyte.Resources(
cpu="4",
memory="24Gi",
disk="100Gi",
shm="16Gi"
)
)()
print(result)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/resources/resources.py)
## Resource types
### CPU resources
CPU can be specified in several formats:
```python
# String formats (Kubernetes-style)
flyte.Resources(cpu="500m") # 500 milliCPU (0.5 cores)
flyte.Resources(cpu="2") # 2 CPU cores
flyte.Resources(cpu="1.5") # 1.5 CPU cores
# Numeric formats
flyte.Resources(cpu=1) # 1 CPU core
flyte.Resources(cpu=0.5) # 0.5 CPU cores
# Request and limit ranges
flyte.Resources(cpu=("1", "2")) # Request 1 core, limit to 2 cores
flyte.Resources(cpu=(1, 4)) # Request 1 core, limit to 4 cores
```
### Memory resources
Memory specifications follow Kubernetes conventions:
```python
# Standard memory units
flyte.Resources(memory="512Mi") # 512 MiB
flyte.Resources(memory="1Gi") # 1 GiB
flyte.Resources(memory="2Gi") # 2 GiB
flyte.Resources(memory="500M") # 500 MB (decimal)
flyte.Resources(memory="1G") # 1 GB (decimal)
# Request and limit ranges
flyte.Resources(memory=("1Gi", "4Gi")) # Request 1 GiB, limit to 4 GiB
```
### GPU resources
Flyte supports various GPU types and configurations:
#### Simple GPU allocation
```python
# Basic GPU count
flyte.Resources(gpu=1) # 1 GPU (any available type)
flyte.Resources(gpu=4) # 4 GPUs
# Specific GPU types with quantity
flyte.Resources(gpu="T4:1") # 1 NVIDIA T4 GPU
flyte.Resources(gpu="A100:2") # 2 NVIDIA A100 GPUs
flyte.Resources(gpu="H100:8") # 8 NVIDIA H100 GPUs
```
#### Advanced GPU configuration
You can also use the `GPU` helper class for more detailed configurations:
```python
# Using the GPU helper class
gpu_config = flyte.GPU(device="A100", quantity=2)
flyte.Resources(gpu=gpu_config)
# GPU with memory partitioning (A100 only)
partitioned_gpu = flyte.GPU(
device="A100",
quantity=1,
partition="1g.5gb" # 1/7th of A100 with 5GB memory
)
flyte.Resources(gpu=partitioned_gpu)
# A100 80GB with partitioning
large_partition = flyte.GPU(
device="A100 80G",
quantity=1,
partition="7g.80gb" # Full A100 80GB
)
flyte.Resources(gpu=large_partition)
```
#### Supported GPU types
- **T4**: Entry-level training and inference
- **L4**: Optimized for AI inference
- **L40s**: High-performance compute
- **A100**: High-end training and inference (40GB)
- **A100 80G**: High-end training with more memory (80GB)
- **H100**: Latest generation, highest performance
### Custom device specifications
You can also define custom devices if your infrastructure supports them:
```python
# Custom device configuration
custom_device = flyte.Device(
device="custom_accelerator",
quantity=2,
partition="large"
)
resources = flyte.Resources(gpu=custom_device)
```
### TPU resources
For Google Cloud TPU workloads you can specify TPU resources using the `TPU` helper class:
```python
# TPU v5p configuration
tpu_config = flyte.TPU(device="V5P", partition="2x2x1")
flyte.Resources(gpu=tpu_config) # Note: TPUs use the gpu parameter
# TPU v6e configuration
tpu_v6e = flyte.TPU(device="V6E", partition="4x4")
flyte.Resources(gpu=tpu_v6e)
```
### Storage resources
Flyte provides two types of storage resources for tasks: ephemeral disk storage and shared memory.
These resources are essential for tasks that need temporary storage for processing data, caching intermediate results, or sharing data between processes.
#### Disk storage
Ephemeral disk storage provides temporary space for your tasks to store intermediate files, downloaded datasets, model checkpoints, and other temporary data. This storage is automatically cleaned up when the task completes.
```python
flyte.Resources(disk="10Gi") # 10 GiB ephemeral storage
flyte.Resources(disk="100Gi") # 100 GiB ephemeral storage
flyte.Resources(disk="1Ti") # 1 TiB for large-scale data processing
# Common use cases
flyte.Resources(disk="50Gi") # ML model training with checkpoints
flyte.Resources(disk="200Gi") # Large dataset preprocessing
flyte.Resources(disk="500Gi") # Video/image processing workflows
```
#### Shared memory
Shared memory (`/dev/shm`) is a high-performance, RAM-based storage area that can be shared between processes within the same container. It's particularly useful for machine learning workflows that need fast data loading and inter-process communication.
```python
flyte.Resources(shm="1Gi") # 1 GiB shared memory (/dev/shm)
flyte.Resources(shm="auto") # Auto-sized shared memory
flyte.Resources(shm="16Gi") # Large shared memory for distributed training
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/secrets ===
# Secrets
Flyte secrets enable you to securely store and manage sensitive information, such as API keys, passwords, and other credentials.
Secrets reside in a secret store on the data plane of your Union/Flyte backend.
You can create, list, and delete secrets in the store using the Flyte CLI or SDK.
Secrets in the store can be accessed and used within your workflow tasks, without exposing any cleartext values in your code.
## Creating a literal string secret
You can create a secret using the **Flyte CLI > flyte > flyte create > flyte create secret** command like this:
```shell
flyte create secret MY_SECRET_KEY my_secret_value
```
This will create a secret called `MY_SECRET_KEY` with the value `my_secret_value`.
This secret will be scoped to your entire organization.
It will be available across all projects and domains in your organization.
See the **Configure tasks > Secrets > Scoping secrets** section below for more details.
See **Configure tasks > Secrets > Using a literal string secret** for how to access the secret in your task code.
## Creating a file secret
You can also create a secret by specifying a local file:
```shell
flyte create secret MY_SECRET_KEY --from-file /local/path/to/my_secret_file
```
In this case, when accessing the secret in your task code, you will need to mount it as a file (see **Configure tasks > Secrets > Using a file secret**).
## Scoping secrets
When you create a secret without specifying a project or domain, as we did above, the secret is scoped to the organization level.
This means that the secret will be available across all projects and domains in the organization.
You can optionally specify either or both of the `--project` and `--domain` flags to restrict the scope of the secret to:
* A specific project (across all domains)
* A specific domain (across all projects)
* A specific project and a specific domain.
For example, to create a secret that is only available in `my_project/development`, you would execute the following command:
```shell
flyte create secret --project my_project --domain development MY_SECRET_KEY my_secret_value
```
## Listing secrets
You can list existing secrets with the **Flyte CLI > flyte > flyte get > flyte get secret** command.
For example, the following command will list all secrets in the organization:
```shell
$ flyte get secret
```
Specifying either or both of the `--project` and `--domain` flags will list the secrets that are **only** available in that project and/or domain.
For example, to list the secrets that are only available in `my_project` and domain `development`, you would run:
```shell
flyte get secret --project my_project --domain development
```
## Deleting secrets
To delete a secret, use the **Flyte CLI > flyte > flyte delete > flyte delete secret** command:
```shell
flyte delete secret MY_SECRET_KEY
```
## Using a literal string secret
To use a literal string secret, specify it in the `TaskEnvironment` along with the name of the environment variable into which it will be injected.
You can then access it using `os.getenv()` in your task code.
For example:
```
import os

import flyte

env_1 = flyte.TaskEnvironment(
name="env_1",
secrets=[
flyte.Secret(key="my_secret", as_env_var="MY_SECRET_ENV_VAR"),
]
)
@env_1.task
def task_1():
my_secret_value = os.getenv("MY_SECRET_ENV_VAR")
print(f"My secret value is: {my_secret_value}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/secrets/secrets.py)
## Using a file secret
To use a file secret, specify it in the `TaskEnvironment` along with the `mount="/etc/flyte/secrets"` argument (with that precise value).
The file will be mounted at `/etc/flyte/secrets/`.
For example:
```
env_2 = flyte.TaskEnvironment(
name="env_2",
secrets=[
flyte.Secret(key="my_secret", mount="/etc/flyte/secrets"),
]
)
@env_2.task
def task_2():
with open("/etc/flyte/secrets/my_secret", "r") as f:
my_secret_file_content = f.read()
print(f"My secret file content is: {my_secret_file_content}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/secrets/secrets.py)
> [!NOTE]
> Currently, to access a file secret you must specify a `mount` parameter value of `"/etc/flyte/secrets"`.
> This fixed path is the directory in which the secret file will be placed.
> The name of the secret file will be equal to the key of the secret.
> [!NOTE]
> A `TaskEnvironment` can only access a secret if the scope of the secret includes the project and domain where the `TaskEnvironment` is deployed.
> [!WARNING]
> Do not return secret values from tasks, as this will expose secrets to the control plane.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/caching ===
# Caching
Flyte 2 provides intelligent **task output caching** that automatically avoids redundant computation by reusing previously computed task results.
> [!NOTE]
> Caching works at the task level and caches complete task outputs.
> For function-level checkpointing and resumption *within tasks*, see **Build tasks > Traces**.
## Overview
By default, caching is disabled.
If caching is enabled for a task, then Flyte determines a **cache key** for the task.
The key is composed of the following:
* Final inputs: The set of inputs after removing any specified in `ignored_inputs`.
* Task name: The fully-qualified name of the task.
* Interface hash: A hash of the task's input and output types.
* Cache version: The cache version string.
If the cache behavior is set to `"auto"`, the cache version is automatically generated using a hash of the task's source code (or according to the custom policy if one is specified).
If the cache behavior is set to `"override"`, the cache version can be specified explicitly using the `version_override` parameter.
When the task runs, Flyte checks if a cache entry exists for the key.
If found, the cached result is returned immediately instead of re-executing the task.
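Conceptually, the key derivation resembles the following sketch (the function and hashing details are illustrative, not Flyte's actual internals):
```python
import hashlib

def cache_key(inputs: dict, ignored_inputs: tuple, task_name: str,
              interface_hash: str, cache_version: str) -> str:
    # Illustrative only: combine the final inputs, task name, interface hash,
    # and cache version into a single lookup key, as described above.
    final_inputs = sorted((k, repr(v)) for k, v in inputs.items() if k not in ignored_inputs)
    material = f"{task_name}|{interface_hash}|{cache_version}|{final_inputs}"
    return hashlib.sha256(material.encode()).hexdigest()
```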
## Basic caching usage
Flyte 2 supports three main cache behaviors:
### `"auto"` - Automatic versioning
```
@env.task(cache=flyte.Cache(behavior="auto"))
async def auto_versioned_task(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
With `behavior="auto"`, the cache version is automatically generated based on the function's source code.
If you change the function implementation, the cache is automatically invalidated.
- **When to use**: Development and most production scenarios.
- **Cache invalidation**: Automatic when function code changes.
- **Benefits**: Zero-maintenance caching that "just works".
You can also use the direct string shorthand:
```
@env.task(cache="auto")
async def auto_versioned_task_2(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
### `"override"`
With `behavior="override"`, you specify a custom cache version via the `version_override` parameter.
Since this version is fixed in the code, you can change it manually whenever you need to invalidate the cache.
```
@env.task(cache=flyte.Cache(behavior="override", version_override="v1.2"))
async def manually_versioned_task(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
- **When to use**: When you need explicit control over cache invalidation.
- **Cache invalidation**: Manual, by changing `version_override`.
- **Benefits**: Stable caching across code changes that don't affect logic.
### `"disable"` - No caching
To explicitly disable caching, use the `"disable"` behavior.
**This is the default behavior.**
```
@env.task(cache=flyte.Cache(behavior="disable"))
async def always_fresh_task(data: str) -> str:
return get_current_timestamp() + await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
- **When to use**: Non-deterministic functions, side effects, or always-fresh data.
- **Cache invalidation**: N/A - never cached.
- **Benefits**: Ensures execution every time.
You can also use the direct string shorthand:
```
@env.task(cache="disable")
async def always_fresh_task_2(data: str) -> str:
return get_current_timestamp() + await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
## Advanced caching configuration
### Ignoring specific inputs
Sometimes you want to cache based on some inputs but not others:
```
@env.task(cache=flyte.Cache(behavior="auto", ignored_inputs=("debug_flag",)))
async def selective_caching(data: str, debug_flag: bool) -> str:
if debug_flag:
print(f"Debug: transforming {data}")
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
**This is useful for**:
- Debug flags that don't affect computation
- Logging levels or output formats
- Metadata that doesn't impact results
### Cache serialization
Cache serialization ensures that only one instance of a task runs at a time for identical inputs:
```
@env.task(cache=flyte.Cache(behavior="auto", serialize=True))
async def expensive_model_training(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
**When to use serialization**:
- Very expensive computations (model training, large data processing)
- Shared resources that shouldn't be accessed concurrently
- Operations where multiple parallel executions provide no benefit
**How it works**:
1. First execution acquires a reservation and runs normally.
2. Concurrent executions with identical inputs wait for the first to complete.
3. Once complete, all waiting executions receive the cached result.
4. If the running execution fails, another waiting execution takes over.
### Salt for cache key variation
Use `salt` to vary cache keys without changing function logic:
```
@env.task(cache=flyte.Cache(behavior="auto", salt="experiment_2024_q4"))
async def experimental_analysis(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
**`salt` is useful for**:
- A/B testing with identical code.
- Temporary cache namespaces for experiments.
- Environment-specific cache isolation.
## Cache policies
For `behavior="auto"`, Flyte uses cache policies to generate version hashes.
### Function body policy (default)
The default `FunctionBodyPolicy` generates cache versions from the function's source code:
```
from flyte._cache import FunctionBodyPolicy
@env.task(cache=flyte.Cache(
behavior="auto",
policies=[FunctionBodyPolicy()] # This is the default. Does not actually need to be specified.
))
async def code_sensitive_task(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
### Custom cache policies
You can implement custom cache policies by following the `CachePolicy` protocol:
```
from flyte._cache import CachePolicy
class DatasetVersionPolicy(CachePolicy):
def get_version(self, salt: str, params) -> str:
# Generate version based on custom logic
dataset_version = get_dataset_version()
return f"{salt}_{dataset_version}"
@env.task(cache=flyte.Cache(behavior="auto", policies=[DatasetVersionPolicy()]))
async def dataset_dependent_task(data: str) -> str:
# Cache invalidated when dataset version changes
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
## Caching configuration at different levels
You can configure caching at three levels: `TaskEnvironment` definition, `@env.task` decorator, and task invocation.
### `TaskEnvironment` level
You can configure caching at the `TaskEnvironment` level.
This will set the default cache behavior for all tasks defined using that environment.
For example:
```
cached_env = flyte.TaskEnvironment(
name="cached_environment",
cache=flyte.Cache(behavior="auto") # Default for all tasks
)
@cached_env.task # Inherits auto caching from environment
async def inherits_caching(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
### `@env.task` decorator level
By setting the cache parameter in the `@env.task` decorator, you can override the environment's default cache behavior for specific tasks:
```
@cached_env.task(cache=flyte.Cache(behavior="disable")) # Override environment default
async def decorator_caching(data: str) -> str:
return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
### `task.override` level
By setting the cache parameter in the `task.override` method, you can override the cache behavior for specific task invocations:
```
@env.task
async def override_caching_on_call(data: str) -> str:
# Create an overridden version and call it
overridden_task = inherits_caching.override(cache=flyte.Cache(behavior="disable"))
return await overridden_task(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
## Runtime cache control
You can also force cache invalidation for a specific run:
```python
# Disable caching for this specific execution
run = flyte.with_runcontext(overwrite_cache=True).run(my_cached_task, data="test")
```
## Project and domain cache isolation
Caches are automatically isolated by:
- **Project**: Tasks in different projects have separate cache namespaces.
- **Domain**: Development, staging, and production domains maintain separate caches.
## Local development caching
When running locally, Flyte maintains a local cache:
```python
# Local execution uses ~/.flyte/local-cache/
flyte.init() # Local mode
result = flyte.run(my_cached_task, data="test")
```
Local cache behavior:
- Stored in `~/.flyte/local-cache/` directory
- No project/domain isolation (since running locally)
- Can be cleared with `flyte local-cache clear`
- Disabled by setting `FLYTE_LOCAL_CACHE_ENABLED=false`
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/reusable-containers ===
# Reusable containers
By default, each task execution in Flyte and Union runs in a fresh container instance that is created just for that execution and then discarded.
With reusable containers, the same container can be reused across multiple executions and tasks.
This approach reduces startup overhead and improves resource efficiency.
> [!NOTE]
> The reusable container feature is only available when running your Flyte code on a Union backend.
> See [one of the Union.ai product variants of this page](/docs/v2/byoc//user-guide/reusable-containers) for details.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/pod-templates ===
# Pod templates
The `pod_template` parameter in `TaskEnvironment` (and in the `@env.task` decorator, if you are overriding) allows you to customize the Kubernetes pod specification that will be used to run your tasks.
This provides fine-grained control over the underlying Kubernetes resources, enabling you to configure advanced pod settings like image pull secrets, environment variables, labels, annotations, and other pod-level configurations.
## Overview
Pod templates in Flyte allow you to:
- **Configure pod metadata**: Set custom labels and annotations for your pods.
- **Specify image pull secrets**: Access private container registries.
- **Set environment variables**: Configure container-level environment variables.
- **Customize pod specifications**: Define advanced Kubernetes pod settings.
- **Control container configurations**: Specify primary container settings.
The `pod_template` parameter accepts either a string reference or a `PodTemplate` object that defines the complete pod specification.
## Basic usage
Here's a complete example showing how to use pod templates with a `TaskEnvironment`:
```
# /// script
# requires-python = "==3.12"
# dependencies = [
# "flyte==2.0.0b31",
# "kubernetes"
# ]
# ///
import flyte
from kubernetes.client import (
V1Container,
V1EnvVar,
V1LocalObjectReference,
V1PodSpec,
)
# Create a custom pod template
pod_template = flyte.PodTemplate(
primary_container_name="primary", # Name of the main container
labels={"lKeyA": "lValA"}, # Custom pod labels
annotations={"aKeyA": "aValA"}, # Custom pod annotations
pod_spec=V1PodSpec( # Kubernetes pod specification
containers=[
V1Container(
name="primary",
env=[V1EnvVar(name="hello", value="world")] # Environment variables
)
],
image_pull_secrets=[ # Access to private registries
V1LocalObjectReference(name="regcred-test")
],
),
)
# Use the pod template in a TaskEnvironment
env = flyte.TaskEnvironment(
name="hello_world",
pod_template=pod_template, # Apply the custom pod template
image=flyte.Image.from_uv_script(__file__, name="flyte", pre=True),
)
@env.task
async def say_hello(data: str) -> str:
return f"Hello {data}"
@env.task
async def say_hello_nested(data: str = "default string") -> str:
return await say_hello(data=data)
if __name__ == "__main__":
flyte.init_from_config()
result = flyte.run(say_hello_nested, data="hello world")
print(result.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/pod-templates/pod_template.py)
## PodTemplate components
The `PodTemplate` class provides the following parameters for customizing your pod configuration:
```python
pod_template = flyte.PodTemplate(
primary_container_name: str = "primary",
pod_spec: Optional[V1PodSpec] = None,
labels: Optional[Dict[str, str]] = None,
annotations: Optional[Dict[str, str]] = None
)
```
### Parameters
- **`primary_container_name`** (`str`, default: `"primary"`): Specifies the name of the main container that will run your task code. This must match the container name defined in your pod specification.
- **`pod_spec`** (`Optional[V1PodSpec]`): A standard Kubernetes `V1PodSpec` object that defines the complete pod specification. This allows you to configure any pod-level setting including containers, volumes, security contexts, node selection, and more.
- **`labels`** (`Optional[Dict[str, str]]`): Key-value pairs used for organizing and selecting pods. Labels are used by Kubernetes selectors and can be queried to filter and manage pods.
- **`annotations`** (`Optional[Dict[str, str]]`): Additional metadata attached to the pod that doesn't affect pod scheduling or selection. Annotations are typically used for storing non-identifying information like deployment revisions, contact information, or configuration details.
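As noted above, `pod_template` also accepts a string naming a pod template that already exists on your cluster. A minimal sketch (the template name is illustrative):
```python
import flyte

# Reference a named pod template configured on the cluster instead of an inline object.
env = flyte.TaskEnvironment(
    name="named_template_env",
    pod_template="my-cluster-template",  # Hypothetical name of a pre-registered template
)
```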
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/multiple-environments ===
# Multiple environments
In many applications, different tasks within your workflow may require different configurations.
Flyte enables you to manage this complexity by allowing multiple environments within a single workflow.
Multiple environments are useful when:
- Different tasks in your workflow need different dependencies.
- Some tasks require specific CPU/GPU or memory configurations.
- A task requires a secret that other tasks do not (and you want to limit exposure of the secret value).
- You're integrating specialized tools that have conflicting requirements.
## Constraints on multiple environments
To use multiple environments in your workflow you define multiple `TaskEnvironment` instances, each with its own configuration, and then assign tasks to their respective environments.
There are, however, two additional constraints that you must take into account.
If `task_1` in environment `env_1` calls `task_2` in environment `env_2`, then:
1. `env_1` must declare a deployment-time dependency on `env_2` in the `depends_on` parameter of the `TaskEnvironment` that defines `env_1`.
2. The image used in the `TaskEnvironment` of `env_1` must include all dependencies of the module containing `task_2` (unless `task_2` is invoked as a remote task).
### Task `depends_on` constraints
The `depends_on` parameter in `TaskEnvironment` is used to provide deployment-time dependencies by establishing a relationship between one `TaskEnvironment` and another.
The system uses this information to determine which environments (and, specifically, which images) need to be built in order to run the code.
On `flyte run` (or `flyte deploy`), the system walks the tree defined by the `depends_on` relationships, starting with the environment of the task being invoked (or the environment being deployed, in the case of `flyte deploy`), and prepares each required environment.
Most importantly, it ensures that the container images needed for all required environments are available (and, if not, it builds them).
This deploy-time determination of what to build is important because it means that for any given `run` or `deploy`, only those environments that are actually required are built.
The alternative strategy of building all environments defined in the set of deployed code can lead to unnecessary and expensive builds, especially when iterating on code.
### Dependency inclusion constraints
When a parent task invokes a child task in a different environment, the container image of the parent task environment must include all dependencies used by the child task.
This is necessary because of the way task invocation works in Flyte:
- When a child task is invoked by function name, that function necessarily has to be imported into the parent task's Python environment.
- This results in all the dependencies of the child task function also being imported.
- Nonetheless, the actual execution of the child task occurs in its own environment.
To avoid this requirement, you can invoke a task in another environment _remotely_.
## Example
The following example is a (very) simple mock of an AlphaFold2 pipeline.
It demonstrates a workflow with three tasks, each in its own environment.
The example project looks like this:
```bash
├── msa/
│   ├── __init__.py
│   └── run.py
├── fold/
│   ├── __init__.py
│   └── run.py
├── __init__.py
└── main.py
```
(The source code for this example can be found here: [AlphaFold2 mock example](https://github.com/unionai/unionai-examples/tree/main/v2/user-guide/task-configuration/multiple-environments/af2))
In file `msa/run.py` we define the task `run_msa`, which mocks the multiple sequence alignment step of the process:
```python
import flyte
from flyte.io import File
MSA_PACKAGES = ["pytest"]
msa_image = flyte.Image.from_debian_base().with_pip_packages(*MSA_PACKAGES)
msa_env = flyte.TaskEnvironment(name="msa_env", image=msa_image)
@msa_env.task
def run_msa(x: str) -> File:
f = File.new_remote()
with f.open_sync("w") as fp:
fp.write(x)
return f
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/msa/run.py)
* A dedicated image (`msa_image`) is built using the `MSA_PACKAGES` dependency list, on top of the standard base image.
* A dedicated environment (`msa_env`) is defined for the task, using `msa_image`.
* The task is defined within the context of the `msa_env` environment.
In file `fold/run.py` we define the task `run_fold`, which mocks the fold step of the process:
```python
import flyte
from flyte.io import File
FOLD_PACKAGES = ["ruff"]
fold_image = flyte.Image.from_debian_base().with_pip_packages(*FOLD_PACKAGES)
fold_env = flyte.TaskEnvironment(name="fold_env", image=fold_image)
@fold_env.task
def run_fold(sequence: str, msa: File) -> list[str]:
with msa.open_sync("r") as f:
msa_content = f.read()
return [msa_content, sequence]
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/fold/run.py)
* A dedicated image (`fold_image`) is built using the `FOLD_PACKAGES` dependency list, on top of the standard base image.
* A dedicated environment (`fold_env`) is defined for the task, using `fold_image`.
* The task is defined within the context of the `fold_env` environment.
Finally, in file `main.py` we define the task `main` that ties everything together into a workflow.
We import the required modules and functions:
```
import logging
import pathlib
from fold.run import fold_env, fold_image, run_fold
from msa.run import msa_env, MSA_PACKAGES, run_msa
import flyte
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py)
Notice that we import
* The task functions that we will be calling: `run_fold` and `run_msa`.
* The environments of those tasks: `fold_env` and `msa_env`.
* The dependency list of the `run_msa` task: `MSA_PACKAGES`.
* The image of the `run_fold` task: `fold_image`.
We then assemble the image and the environment:
```
main_image = fold_image.with_pip_packages(*MSA_PACKAGES)
env = flyte.TaskEnvironment(
name="multi_env",
depends_on=[fold_env, msa_env],
image=main_image,
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py)
The image for the `main` task (`main_image`) is built by starting with `fold_image` (the image for the `run_fold` task) and adding `MSA_PACKAGES` (the dependency list for the `run_msa` task).
This ensures that `main_image` includes all dependencies needed by both the `run_fold` and `run_msa` tasks.
The environment for the `main` task is defined with:
* The image `main_image`. This ensures that the `main` task has all the dependencies it needs.
* A `depends_on` list that includes both `fold_env` and `msa_env`. This establishes the deploy-time dependencies on those environments.
Finally, we define the `main` task itself:
```
@env.task
def main(sequence: str) -> list[str]:
"""Given a sequence, outputs files containing the protein structure
This requires model weights + gpus + large database on aws fsx lustre
"""
print(f"Running AlphaFold2 for sequence: {sequence}")
msa = run_msa(sequence)
print(f"MSA result: {msa}, passing to fold task")
results = run_fold(sequence, msa)
print(f"Fold results: {results}")
return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py)
Here we call, in turn, the `run_msa` and `run_fold` tasks.
Since we call them directly rather than as remote tasks, we had to ensure that `main_image` includes all dependencies needed by both tasks.
The final piece of the puzzle is the `if __name__ == "__main__":` block that allows us to run the `main` task on the configured Flyte backend:
```
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main, "AAGGTTCCAA")
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py)
Now you can run the workflow with:
```bash
python main.py
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/retries-and-timeouts ===
# Retries and timeouts
Flyte provides robust error handling through configurable retry strategies and timeout controls.
These parameters help ensure task reliability and prevent resource waste from runaway processes.
## Retries
The `retries` parameter controls how many times a failed task should be retried before giving up.
A "retry" is any attempt after the initial attempt.
In other words, `retries=3` means the task may be attempted up to 4 times in total (1 initial + 3 retries).
The `retries` parameter can be configured in either the `@env.task` decorator or using `override` when invoking the task.
It cannot be configured in the `TaskEnvironment` definition.
The code for the examples below can be found on [GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py).
### Retry example
First we import the required modules and set up a task environment:
```
import random
from datetime import timedelta
import flyte
env = flyte.TaskEnvironment(name="my-env")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py)
Then we configure our task to retry up to 3 times if it fails (for a total of 4 attempts). We also define the driver task `main` that calls the `retry` task:
```
@env.task(retries=3)
async def retry() -> str:
if random.random() < 0.7: # 70% failure rate
raise Exception("Task failed!")
return "Success!"
@env.task
async def main() -> list[str]:
results = []
try:
results.append(await retry())
except Exception as e:
results.append(f"Failed: {e}")
try:
results.append(await retry.override(retries=5)())
except Exception as e:
results.append(f"Failed: {e}")
return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py)
Note that we call `retry` twice: first without any `override`, and then with an `override` to increase the retries to 5 (for a total of 6 attempts).
Finally, we configure flyte and invoke the `main` task:
```
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py)
## Timeouts
The `timeout` parameter sets limits on how long a task can run, preventing resource waste from stuck processes.
It supports multiple formats for different use cases.
The `timeout` parameter can be configured in either the `@env.task` decorator or using `override` when invoking the task.
It cannot be configured in the `TaskEnvironment` definition.
The code for the example below can be found on [GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py).
### Timeout example
First, we import the required modules and set up a task environment:
```
import random
from datetime import timedelta
import asyncio
import flyte
from flyte import Timeout
env = flyte.TaskEnvironment(name="my-env")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py)
Our first task sets a timeout using seconds as an integer:
```
@env.task(timeout=60) # 60 seconds
async def timeout_seconds() -> str:
await asyncio.sleep(random.randint(0, 120)) # Random wait between 0 and 120 seconds
return "timeout_seconds completed"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py)
We can also set a timeout using a `timedelta` object for more readable durations:
```
@env.task(timeout=timedelta(minutes=1))
async def timeout_timedelta() -> str:
await asyncio.sleep(random.randint(0, 120)) # Random wait between 0 and 120 seconds
return "timeout_timedelta completed"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py)
You can also set separate timeouts for maximum execution time and maximum queue time using the `Timeout` class:
```
@env.task(timeout=Timeout(
max_runtime=timedelta(minutes=1), # Max execution time per attempt
max_queued_time=timedelta(minutes=1) # Max time in queue before starting
))
async def timeout_advanced() -> str:
await asyncio.sleep(random.randint(0, 120)) # Random wait between 0 and 120 seconds
return "timeout_advanced completed"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py)
You can also combine retries and timeouts for resilience and resource control:
```
@env.task(
retries=3,
timeout=Timeout(
max_runtime=timedelta(minutes=1),
max_queued_time=timedelta(minutes=1)
)
)
async def timeout_with_retry() -> str:
await asyncio.sleep(random.randint(0, 120)) # Random wait between 0 and 120 seconds
    return "timeout_with_retry completed"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py)
Here we specify:
- Up to 3 retries (4 attempts in total).
- Each attempt times out after 1 minute.
- The task fails if queued for more than 1 minute.
- Total possible runtime: 1 minute of queue time + (1 minute × 4 attempts).
We define the `main` driver task that calls all the timeout tasks concurrently and returns their outputs as a list. The return value for failed tasks will indicate failure:
```
@env.task
async def main() -> list[str]:
tasks = [
timeout_seconds(),
timeout_seconds.override(timeout=120)(), # Override to 120 seconds
timeout_timedelta(),
timeout_advanced(),
timeout_with_retry(),
]
results = await asyncio.gather(*tasks, return_exceptions=True)
output = []
for r in results:
if isinstance(r, Exception):
output.append(f"Failed: {r}")
else:
output.append(r)
return output
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py)
Note that we also demonstrate overriding the timeout for `timeout_seconds` to 120 seconds when calling it.
Finally, we configure Flyte and invoke the `main` task:
```
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py)
Proper retry and timeout configuration ensures your Flyte workflows are both reliable and efficient, handling transient failures gracefully while preventing resource waste.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-configuration/triggers ===
# Triggers
Triggers allow you to automate and parameterize an execution by scheduling its start time and providing overrides for its task inputs.
Currently, only **schedule triggers** are supported.
This type of trigger runs a task based on a Cron expression or a fixed-rate schedule.
Support is coming for other trigger types, such as:
* Webhook triggers: Hit an API endpoint to run your task.
* Artifact triggers: Run a task when a specific artifact is produced.
## Triggers are set in the task decorator
A trigger is created by setting the `triggers` parameter in the task decorator to a `flyte.Trigger` object or a list of such objects (triggers are not settable at the `TaskEnvironment` definition or `task.override` levels).
Here is a simple example:
```
import flyte
from datetime import datetime, timezone
env = flyte.TaskEnvironment(name="trigger_env")
@env.task(triggers=flyte.Trigger.hourly()) # Every hour
def hourly_task(trigger_time: datetime, x: int = 1) -> str:
return f"Hourly example executed at {trigger_time.isoformat()} with x={x}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
Here we use a predefined schedule trigger to run the `hourly_task` every hour.
Other predefined triggers can be used similarly (see **Configure tasks > Triggers > Predefined schedule triggers** below).
If you want full control over the trigger behavior, you can define a trigger using the `flyte.Trigger` class directly.
## `flyte.Trigger`
The `Trigger` class allows you to define custom triggers with full control over scheduling and execution behavior. It has the following signature:
```
flyte.Trigger(
name,
automation,
description="",
auto_activate=True,
inputs=None,
env_vars=None,
interruptible=None,
overwrite_cache=False,
queue=None,
labels=None,
annotations=None
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Core Parameters
**`name: str`** (required)
The unique identifier for the trigger within your project/domain.
**`automation: Union[Cron, FixedRate]`** (required)
Defines when the trigger fires. Use `flyte.Cron("expression")` for Cron-based scheduling or `flyte.FixedRate(interval_minutes, start_time=start_time)` for fixed intervals.
### Configuration Parameters
**`description: str = ""`**
Human-readable description of the trigger's purpose.
**`auto_activate: bool = True`**
Whether the trigger should be automatically activated when deployed. Set to `False` to deploy inactive triggers that require manual activation.
**`inputs: Dict[str, Any] | None = None`**
Default parameter values for the task when triggered. Use `flyte.TriggerTime` as a value to inject the trigger execution timestamp into that parameter.
### Runtime Override Parameters
**`env_vars: Dict[str, str] | None = None`**
Environment variables to set for triggered executions, overriding the task's default environment variables.
**`interruptible: bool | None = None`**
Whether triggered executions can be interrupted (useful for cost optimization with spot/preemptible instances). Overrides the task's interruptible setting.
**`overwrite_cache: bool = False`**
Whether to bypass/overwrite task cache for triggered executions, ensuring fresh computation.
**`queue: str | None = None`**
Specific execution queue for triggered runs, overriding the task's default queue.
### Metadata Parameters
**`labels: Mapping[str, str] | None = None`**
Key-value labels for organizing and filtering triggers (e.g., team, component, priority).
**`annotations: Mapping[str, str] | None = None`**
Additional metadata, often used by infrastructure tools for compliance, monitoring, or cost tracking.
Here's a comprehensive example showing all parameters:
```
comprehensive_trigger = flyte.Trigger(
name="monthly_financial_report",
automation=flyte.Cron("0 6 1 * *", timezone="America/New_York"),
description="Monthly financial report generation for executive team",
auto_activate=True,
inputs={
"report_date": flyte.TriggerTime,
"report_type": "executive_summary",
"include_forecasts": True
},
env_vars={
"REPORT_OUTPUT_FORMAT": "PDF",
"EMAIL_NOTIFICATIONS": "true"
},
interruptible=False, # Critical report, use dedicated resources
overwrite_cache=True, # Always fresh data
queue="financial-reports",
labels={
"team": "finance",
"criticality": "high",
"automation": "scheduled"
},
annotations={
"compliance.company.com/sox-required": "true",
"backup.company.com/retain-days": "2555" # 7 years
}
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
## The `automation` parameter with `flyte.FixedRate`
You can define a fixed-rate schedule trigger by setting the `automation` parameter of the `flyte.Trigger` to an instance of `flyte.FixedRate`.
The `flyte.FixedRate` has the following signature:
```
flyte.FixedRate(
interval_minutes,
start_time=None
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Parameters
**`interval_minutes: int`** (required)
The interval between trigger executions in minutes.
**`start_time: datetime | None`**
When to start the fixed rate schedule. If not specified, starts when the trigger is deployed and activated.
### Examples
```
# Every 90 minutes, starting when deployed
every_90_min = flyte.Trigger(
"data_processing",
flyte.FixedRate(interval_minutes=90)
)
# Every 6 hours (360 minutes), starting at a specific time
specific_start = flyte.Trigger(
"batch_job",
flyte.FixedRate(
interval_minutes=360, # 6 hours
start_time=datetime(2025, 12, 1, 9, 0, 0) # Start Dec 1st at 9 AM
)
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
## The `automation` parameter with `flyte.Cron`
You can define a Cron-based schedule trigger by setting the `automation` parameter to an instance of `flyte.Cron`.
The `flyte.Cron` has the following signature:
```
flyte.Cron(
cron_expression,
timezone=None
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Parameters
**`cron_expression: str`** (required)
The cron expression defining when the trigger should fire. Uses standard Unix cron format with five fields: minute, hour, day of month, month, and day of week.
**`timezone: str | None`**
The timezone for the cron expression. If not specified, it defaults to UTC. Uses standard timezone names like "America/New_York" or "Europe/London".
### Examples
```
# Every day at 6 AM UTC
daily_trigger = flyte.Trigger(
"daily_report",
flyte.Cron("0 6 * * *")
)
# Every weekday at 9:30 AM Eastern Time
weekday_trigger = flyte.Trigger(
"business_hours_task",
flyte.Cron("30 9 * * 1-5", timezone="America/New_York")
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
#### Cron Expressions
Here are some common cron expressions you can use:
| Expression | Description |
|----------------|--------------------------------------|
| `0 0 * * *` | Every day at midnight |
| `0 9 * * 1-5` | Every weekday at 9 AM |
| `30 14 * * 6` | Every Saturday at 2:30 PM |
| `0 0 1 * *` | First day of every month at midnight |
| `0 0 25 * *` | 25th day of every month at midnight |
| `0 0 * * 0` | Every Sunday at midnight |
| `*/10 * * * *` | Every 10 minutes |
| `0 */2 * * *` | Every 2 hours |
For a full guide on Cron syntax, refer to [Crontab Guru](https://crontab.guru/).
## The `inputs` parameter
The `inputs` parameter allows you to provide default values for your task's parameters when the trigger fires.
This is essential for parameterizing your automated executions and passing trigger-specific data to your tasks.
### Basic Usage
```
trigger_with_inputs = flyte.Trigger(
"data_processing",
flyte.Cron("0 6 * * *"), # Daily at 6 AM
inputs={
"batch_size": 1000,
"environment": "production",
"debug_mode": False
}
)
@env.task(triggers=trigger_with_inputs)
def process_data(batch_size: int, environment: str, debug_mode: bool = True) -> str:
return f"Processing {batch_size} items in {environment} mode"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Using `flyte.TriggerTime`
The special `flyte.TriggerTime` value is used in the `inputs` to indicate the task parameter into which Flyte will inject the trigger execution timestamp:
```
timestamp_trigger = flyte.Trigger(
"daily_report",
flyte.Cron("0 0 * * *"), # Daily at midnight
inputs={
"report_date": flyte.TriggerTime, # Receives trigger execution time
"report_type": "daily_summary"
}
)
@env.task(triggers=timestamp_trigger)
def generate_report(report_date: datetime, report_type: str) -> str:
return f"Generated {report_type} for {report_date.strftime('%Y-%m-%d')}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Required vs optional parameters
> [!IMPORTANT]
> If your task has parameters without default values, you **must** provide values for them in the trigger inputs, otherwise the trigger will fail to execute.
```python
# β This will fail - missing required parameter 'data_source'
bad_trigger = flyte.Trigger(
"bad_trigger",
flyte.Cron("0 0 * * *")
# Missing inputs for required parameter 'data_source'
)
@env.task(triggers=bad_trigger)
def bad_trigger_task(data_source: str, batch_size: int = 100) -> str:
return f"Processing from {data_source} with batch size {batch_size}"
# β This works - all required parameters provided
good_trigger = flyte.Trigger(
"good_trigger",
flyte.Cron("0 0 * * *"),
inputs={
"data_source": "prod_database", # Required parameter
"batch_size": 500 # Override default
}
)
@env.task(triggers=good_trigger)
def good_trigger_task(data_source: str, batch_size: int = 100) -> str:
return f"Processing from {data_source} with batch size {batch_size}"
```
### Complex input types
You can pass various data types through trigger inputs:
```
complex_trigger = flyte.Trigger(
"ml_training",
flyte.Cron("0 2 * * 1"), # Weekly on Monday at 2 AM
inputs={
"model_config": {
"learning_rate": 0.01,
"batch_size": 32,
"epochs": 100
},
"feature_columns": ["age", "income", "location"],
"validation_split": 0.2,
"training_date": flyte.TriggerTime
}
)
@env.task(triggers=complex_trigger)
def train_model(
model_config: dict,
feature_columns: list[str],
validation_split: float,
training_date: datetime
) -> str:
return f"Training model with {len(feature_columns)} features on {training_date}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
## Predefined schedule triggers
For common scheduling needs, Flyte provides predefined trigger methods that create Cron-based schedules without requiring you to specify cron expressions manually.
These are convenient shortcuts for frequently used scheduling patterns.
### Available Predefined Triggers
```
minutely_trigger = flyte.Trigger.minutely() # Every minute
hourly_trigger = flyte.Trigger.hourly() # Every hour
daily_trigger = flyte.Trigger.daily() # Every day at midnight
weekly_trigger = flyte.Trigger.weekly() # Every week (Sundays at midnight)
monthly_trigger = flyte.Trigger.monthly() # Every month (1st day at midnight)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
For reference, here's what each predefined trigger is equivalent to:
```python
# These are functionally identical:
flyte.Trigger.minutely() == flyte.Trigger("minutely", flyte.Cron("* * * * *"))
flyte.Trigger.hourly() == flyte.Trigger("hourly", flyte.Cron("0 * * * *"))
flyte.Trigger.daily() == flyte.Trigger("daily", flyte.Cron("0 0 * * *"))
flyte.Trigger.weekly() == flyte.Trigger("weekly", flyte.Cron("0 0 * * 0"))
flyte.Trigger.monthly() == flyte.Trigger("monthly", flyte.Cron("0 0 1 * *"))
```
### Predefined Trigger Parameters
All predefined trigger methods (`minutely()`, `hourly()`, `daily()`, `weekly()`, `monthly()`) accept the same set of parameters:
```
flyte.Trigger.daily(
trigger_time_input_key="trigger_time",
name="daily",
description="A trigger that runs daily at midnight",
auto_activate=True,
inputs=None,
env_vars=None,
interruptible=None,
overwrite_cache=False,
queue=None,
labels=None,
annotations=None
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
#### Core Parameters
**`trigger_time_input_key: str = "trigger_time"`**
The name of the task parameter that will receive the execution timestamp.
If no `trigger_time_input_key` is provided, the default is `trigger_time`.
In this default case, if the task does not have a parameter named `trigger_time`, the task will still execute, but the timestamp will not be passed.
However, if you explicitly specify a `trigger_time_input_key` and your task does not have the specified parameter, an error will be raised at trigger deployment time.
**`name: str`**
The unique identifier for the trigger. Defaults to the method name (`"daily"`, `"hourly"`, etc.).
**`description: str`**
Human-readable description of the trigger's purpose. Each method has a sensible default.
#### Configuration Parameters
**`auto_activate: bool = True`**
Whether the trigger should be automatically activated when deployed. Set to `False` to deploy inactive triggers that require manual activation.
**`inputs: Dict[str, Any] | None = None`**
Additional parameter values for your task when triggered. The `trigger_time_input_key` parameter is automatically included with `flyte.TriggerTime` as its value.
#### Runtime Override Parameters
**`env_vars: Dict[str, str] | None = None`**
Environment variables to set for triggered executions, overriding the task's default environment variables.
**`interruptible: bool | None = None`**
Whether triggered executions can be interrupted (useful for cost optimization with spot/preemptible instances). Overrides the task's interruptible setting.
**`overwrite_cache: bool = False`**
Whether to bypass/overwrite task cache for triggered executions, ensuring fresh computation.
**`queue: str | None = None`**
Specific execution queue for triggered runs, overriding the task's default queue.
#### Metadata Parameters
**`labels: Mapping[str, str] | None = None`**
Key-value labels for organizing and filtering triggers (e.g., team, component, priority).
**`annotations: Mapping[str, str] | None = None`**
Additional metadata, often used by infrastructure tools for compliance, monitoring, or cost tracking.
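For example, here is a sketch of a predefined trigger that supplies additional inputs and runtime overrides (the names and values are illustrative):
```
nightly = flyte.Trigger.daily(
    name="nightly_ingest",
    description="Nightly ingest from the warehouse",
    inputs={"source": "warehouse"},  # merged with the automatic trigger_time input
    env_vars={"LOG_LEVEL": "INFO"},
    overwrite_cache=True,  # always recompute on schedule
)

@env.task(triggers=nightly)
def nightly_ingest(trigger_time: datetime, source: str) -> str:
    return f"Ingesting {source} at {trigger_time.isoformat()}"
```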
### Trigger time in predefined triggers
By default, predefined triggers will pass the execution time to the task parameter `trigger_time` (of type `datetime`), if that parameter exists on the task.
If no such parameter exists, the task will still be executed without error.
Optionally, you can customize the parameter name that receives the trigger execution timestamp by setting the `trigger_time_input_key` parameter. In this case, if the task does not have the specified parameter, an error will be raised at trigger deployment time:
```
@env.task(triggers=flyte.Trigger.daily(trigger_time_input_key="scheduled_at"))
def task_with_custom_trigger_time_input(scheduled_at: datetime) -> str:
return f"Executed at {scheduled_at}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
## Multiple triggers per task
You can attach multiple triggers to a single task by providing a list of triggers. This allows you to run the same task on different schedules or with different configurations:
```
@env.task(triggers=[
flyte.Trigger.hourly(), # Predefined trigger
flyte.Trigger.daily(), # Another predefined trigger
flyte.Trigger("custom", flyte.Cron("0 */6 * * *")) # Custom trigger every 6 hours
])
def multi_trigger_task(trigger_time: datetime = flyte.TriggerTime) -> str:
# Different logic based on execution timing
if trigger_time.hour == 0: # Daily run at midnight
return f"Daily comprehensive processing at {trigger_time}"
else: # Hourly or custom runs
return f"Regular processing at {trigger_time.strftime('%H:%M')}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
You can mix and match trigger types, combining predefined triggers with triggers that use `flyte.Cron` or `flyte.FixedRate` automations (described above).
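For instance, here is a sketch combining a predefined trigger with a fixed-rate trigger (the task and trigger names are illustrative):
```
@env.task(triggers=[
    flyte.Trigger.daily(),  # predefined, Cron-based
    flyte.Trigger("every_90_min", flyte.FixedRate(90)),  # fixed-rate
])
def mixed_trigger_task(trigger_time: datetime = flyte.TriggerTime) -> str:
    return f"Fired at {trigger_time.isoformat()}"
```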
## Deploying a task with triggers
We recommend that you define your triggers in code together with your tasks and deploy them together.
The Union UI displays:
* `Owner` - who last deployed the trigger.
* `Last updated` - who last activated or deactivated the trigger and when. Note: If you deploy a trigger with `auto_activate=True` (the default), this will match the `Owner`.
* `Last Run` - when the last run was created by this trigger.
For development and debugging purposes, you can adjust and deploy individual triggers from the UI.
To deploy a task with its triggers, you can either use the Flyte CLI:
```shell
flyte deploy -p -d env
```
Or in Python:
```python
flyte.deploy(env)
```
Upon deployment, all triggers associated with a given task `T` will automatically be switched to apply to the latest version of that task. Triggers on task `T` that are defined elsewhere (for example, in the UI) will be deleted unless they are referenced in the task definition of `T`.
## Activating and deactivating triggers
By default, triggers are automatically activated upon deployment (`auto_activate=True`).
Alternatively, you can set `auto_activate=False` to deploy inactive triggers.
An inactive trigger will not create runs until activated.
```
env = flyte.TaskEnvironment(name="my_task_env")
custom_cron_trigger = flyte.Trigger(
"custom_cron",
flyte.Cron("0 0 * * *"),
    auto_activate=False  # Don't create runs yet
)
@env.task(triggers=custom_cron_trigger)
def custom_task() -> str:
return "Hello, world!"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
This trigger won't create runs until it is explicitly activated.
You can activate a trigger via the Flyte CLI:
```shell
flyte update trigger custom_cron my_task_env.custom_task --activate --project --domain
```
If you want to stop your trigger from creating new runs, you can deactivate it:
```shell
flyte update trigger custom_cron my_task_env.custom_task --deactivate --project --domain
```
You can also view and manage your deployed triggers in the Union UI.
## Trigger run timing
The timing of the first run created by a trigger depends on the type of trigger used (Cron-based or Fixed-rate) and whether the trigger is active upon deployment.
### Cron-based triggers
For Cron-based triggers, the first run will be created at the next scheduled time (according to the cron expression) after trigger activation, with subsequent runs at each scheduled time thereafter. For example:
* `0 0 * * *`: If deployed at 17:00 today, the trigger will first fire 7 hours later (at 0:00 the following day) and then every day at 0:00.
* `*/15 14 * * 1-5`: If today is Tuesday at 17:00, the trigger will fire the next day (Wednesday) at 14:00, 14:15, 14:30, and 14:45, and the same on every subsequent weekday.
### Fixed-rate triggers without `start_time`
If no `start_time` is specified, then the first run will be created after the specified interval from the time of activation. No run will be created immediately upon activation, but the activation time will be used as the reference point for future runs.
#### No `start_time`, `auto_activate=True`
Let's say you define a fixed-rate trigger with automatic activation like this:
```
my_trigger = flyte.Trigger("my_trigger", flyte.FixedRate(60))
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
In this case, the first run will occur 60 minutes after the successful deployment of the trigger.
So, if you deployed this trigger at 13:15, the first run will occur at 14:15 and so on thereafter.
#### No `start_time`, `auto_activate=False`
On the other hand, let's say you define a fixed-rate trigger without automatic activation like this:
```
my_trigger = flyte.Trigger("my_trigger", flyte.FixedRate(60), auto_activate=False)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
Suppose you then activate it about 3 hours later. In this case, the first run will kick off 60 minutes after trigger activation.
If you deployed the trigger at 13:15 and activated it at 16:07, the first run will occur at 17:07.
### Fixed-rate triggers with `start_time`
If a `start_time` is specified, the timing of the first run depends on whether the trigger is active at `start_time` or not.
#### Fixed-rate with `start_time` while active
If a `start_time` is specified and the trigger is active at `start_time`, then the first run will occur at `start_time` and at the specified interval thereafter.
For example:
```
my_trigger = flyte.Trigger(
"my_trigger",
# Runs every 60 minutes starting from October 26th, 2025, 10:00am
flyte.FixedRate(60, start_time=datetime(2025, 10, 26, 10, 0, 0)),
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
If you deploy this trigger on October 24th, 2025, the trigger will wait until October 26th 10:00am and will create the first run at exactly 10:00am.
#### Fixed-rate with `start_time` while inactive
If a `start_time` is specified but the trigger is activated after `start_time`, then the first run will be created at the next time point that aligns with the recurring trigger interval, using `start_time` as the initial reference point.
For example:
```
custom_rate_trigger = flyte.Trigger(
"custom_rate",
# Runs every 60 minutes starting from October 26th, 2025, 10:00am
flyte.FixedRate(60, start_time=datetime(2025, 10, 26, 10, 0, 0)),
auto_activate=False
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
If activated later than the `start_time`, say on October 28th at 12:35pm, the first run will be created on October 28th at 1:00pm (the next point aligned with the hourly interval anchored at 10:00am).
## Deleting triggers
If you decide that you don't need a trigger anymore, you can remove the trigger from the task definition and deploy the task again.
Alternatively, you can use the Flyte CLI:
```shell
flyte delete trigger custom_cron my_task_env.custom_task --project --domain
```
## Schedule time zones
### Setting time zone for a Cron schedule
Cron expressions are by default in UTC, but it's possible to specify custom time zones like so:
```
sf_trigger = flyte.Trigger(
"sf_tz",
flyte.Cron(
"0 9 * * *", timezone="America/Los_Angeles"
), # Every day at 9 AM PT
inputs={"start_time": flyte.TriggerTime, "x": 1},
)
nyc_trigger = flyte.Trigger(
"nyc_tz",
flyte.Cron(
"1 12 * * *", timezone="America/New_York"
), # Every day at 12:01 PM ET
inputs={"start_time": flyte.TriggerTime, "x": 1},
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
The above two schedules will fire 1 minute apart, at 9 AM PT and 12:01 PM ET respectively.
### `flyte.TriggerTime` is always in UTC
The `flyte.TriggerTime` value is always in UTC. For timezone-aware logic, convert as needed:
```
@env.task(triggers=flyte.Trigger.minutely(trigger_time_input_key="utc_trigger_time", name="timezone_trigger"))
def timezone_task(utc_trigger_time: datetime) -> str:
local_time = utc_trigger_time.replace(tzinfo=timezone.utc).astimezone()
return f"Task fired at {utc_trigger_time} UTC ({local_time} local)"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Daylight Saving Time behavior
When Daylight Saving Time (DST) begins or ends, it can affect when scheduled executions fire.
On the day DST begins, clocks jump from 2:00AM to 3:00AM, so 2:30AM does not exist that day. A trigger scheduled for 2:30AM will not fire until the next occurrence of 2:30AM, on the following day.
On the day DST ends, the hour from 1:00AM to 2:00AM repeats, so 1:30AM occurs twice. A trigger scheduled for 1:30AM will run only once, at the first occurrence of 1:30AM.
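For reference, here is a minimal sketch of a schedule subject to these transitions (the trigger name is illustrative):
```
dst_trigger = flyte.Trigger(
    "dst_example",
    # 2:30 AM local time: skipped entirely on the day DST begins
    flyte.Cron("30 2 * * *", timezone="America/New_York"),
)
```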
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming ===
# Build tasks
This section covers the essential programming patterns and techniques for developing robust Flyte workflows. Once you understand the basics of task configuration, these guides will help you build sophisticated, production-ready data pipelines and machine learning workflows.
## What you'll learn
The task programming section covers key patterns for building effective Flyte workflows:
**Data handling and types**
- **Build tasks > Files and directories**: Work with large datasets using Flyte's efficient file and directory types that automatically handle data upload, storage, and transfer between tasks.
- **Build tasks > Data classes and structures**: Use Python data classes and Pydantic models as task inputs and outputs to create well-structured, type-safe workflows.
- **Build tasks > Custom context**: Use custom context to pass metadata through your task execution hierarchy without adding parameters to every task.
**Execution patterns**
- **Build tasks > Fanout**: Scale your workflows by running many tasks in parallel, perfect for processing large datasets or running hyperparameter sweeps.
- **Build tasks > Grouping actions**: Organize related task executions into logical groups for better visualization and management in the UI.
**Development and debugging**
- **Build tasks > Notebooks**: Write and iterate on workflows directly in Jupyter notebooks for interactive development and experimentation.
- **Build tasks > Reports**: Generate custom HTML reports during task execution to display progress, results, and visualizations in the UI.
- **Build tasks > Traces**: Add fine-grained observability to helper functions within your tasks for better debugging and resumption capabilities.
- **Build tasks > Error handling**: Implement robust error recovery strategies, including automatic resource scaling and graceful failure handling.
## When to use these patterns
These programming patterns become essential as your workflows grow in complexity:
- Use **fanout** when you need to process multiple items concurrently or run parameter sweeps.
- Implement **error handling** for production workflows that need to recover from infrastructure failures.
- Apply **grouping** to organize complex workflows with many task executions.
- Leverage **files and directories** when working with large datasets that don't fit in memory.
- Use **traces** to debug non-deterministic operations like API calls or ML inference.
- Create **reports** to monitor long-running workflows and share results with stakeholders.
- Use **custom context** when you need lightweight, cross-cutting metadata to flow through your task hierarchy without becoming part of the task's logical inputs.
Each guide includes practical examples and best practices to help you implement these patterns effectively in your own workflows.
## Subpages
- **Build tasks > Data classes and structures**
- **Build tasks > DataFrames**
- **Build tasks > Files and directories**
- **Build tasks > Custom context**
- **Build tasks > Reports**
- **Build tasks > Notebooks**
- **Build tasks > Remote tasks**
- **Build tasks > Error handling**
- **Build tasks > Traces**
- **Build tasks > Grouping actions**
- **Build tasks > Fanout**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/dataclasses-and-structures ===
# Data classes and structures
Dataclasses and Pydantic models are fully supported in Flyte as **materialized data types**:
structured data where the full content is serialized and passed between tasks.
Use these as you would normally, passing them as inputs and outputs of tasks.
Unlike **offloaded types** such as **Build tasks > DataFrames** and **Build tasks > Files and directories**, data class and Pydantic model data is fully serialized, stored, and deserialized between tasks.
This makes them ideal for configuration objects, metadata, and smaller structured data where all fields should be serializable.
## Example: Combining Dataclasses and Pydantic Models
This example demonstrates how data classes and Pydantic models work together as materialized data types, showing nested structures and batch processing patterns:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pydantic",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from dataclasses import dataclass
from typing import List
from pydantic import BaseModel
import flyte
env = flyte.TaskEnvironment(name="ex-mixed-structures")
@dataclass
class InferenceRequest:
feature_a: float
feature_b: float
@dataclass
class BatchRequest:
requests: List[InferenceRequest]
batch_id: str = "default"
class PredictionSummary(BaseModel):
predictions: List[float]
average: float
count: int
batch_id: str
@env.task
async def predict_one(request: InferenceRequest) -> float:
"""
A dummy linear model: prediction = 2 * feature_a + 3 * feature_b + bias(=1.0)
"""
return 2.0 * request.feature_a + 3.0 * request.feature_b + 1.0
@env.task
async def process_batch(batch: BatchRequest) -> PredictionSummary:
"""
Processes a batch of inference requests and returns summary statistics.
"""
# Process all requests concurrently
tasks = [predict_one(request=req) for req in batch.requests]
predictions = await asyncio.gather(*tasks)
# Calculate statistics
average = sum(predictions) / len(predictions) if predictions else 0.0
return PredictionSummary(
predictions=predictions,
average=average,
count=len(predictions),
batch_id=batch.batch_id
)
@env.task
async def summarize_results(summary: PredictionSummary) -> str:
"""
Creates a text summary from the prediction results.
"""
return (
f"Batch {summary.batch_id}: "
f"Processed {summary.count} predictions, "
f"average value: {summary.average:.2f}"
)
@env.task
async def main() -> str:
batch = BatchRequest(
requests=[
InferenceRequest(feature_a=1.0, feature_b=2.0),
InferenceRequest(feature_a=3.0, feature_b=4.0),
InferenceRequest(feature_a=5.0, feature_b=6.0),
],
batch_id="demo_batch_001"
)
summary = await process_batch(batch)
result = await summarize_results(summary)
return result
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataclasses-and-structures/example.py)
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/dataframes ===
# DataFrames
By default, return values in Python are materialized - meaning the actual data is downloaded and loaded into memory. This applies to simple types like integers, as well as more complex types like DataFrames.
To avoid downloading large datasets into memory, Flyte 2 exposes **Build tasks > DataFrames > `flyte.io.DataFrame`**: a thin, uniform wrapper type for DataFrame-style objects that allows you to pass a reference to the data rather than the fully materialized contents.
The `flyte.io.DataFrame` type provides serialization support for common engines like `pandas`, `polars`, `pyarrow`, and `dask`, enabling you to move data between different DataFrame backends.
## Setting up the environment and sample data
For our example, we will start by setting up our task environment with the required dependencies and creating some sample data.
```
from typing import Annotated
import numpy as np
import pandas as pd
import flyte
import flyte.io
env = flyte.TaskEnvironment(
"dataframe_usage",
    image=flyte.Image.from_debian_base().with_pip_packages("pandas", "pyarrow", "numpy"),
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
BASIC_EMPLOYEE_DATA = {
"employee_id": range(1001, 1009),
"name": ["Alice", "Bob", "Charlie", "Diana", "Ethan", "Fiona", "George", "Hannah"],
"department": ["HR", "Engineering", "Engineering", "Marketing", "Finance", "Finance", "HR", "Engineering"],
"hire_date": pd.to_datetime(
["2018-01-15", "2019-03-22", "2020-07-10", "2017-11-01", "2021-06-05", "2018-09-13", "2022-01-07", "2020-12-30"]
),
}
ADDL_EMPLOYEE_DATA = {
"employee_id": range(1001, 1009),
"salary": [55000, 75000, 72000, 50000, 68000, 70000, np.nan, 80000],
"bonus_pct": [0.05, 0.10, 0.07, 0.04, np.nan, 0.08, 0.03, 0.09],
"full_time": [True, True, True, False, True, True, False, True],
"projects": [
["Recruiting", "Onboarding"],
["Platform", "API"],
["API", "Data Pipeline"],
["SEO", "Ads"],
["Budget", "Forecasting"],
["Auditing"],
[],
["Platform", "Security", "Data Pipeline"],
],
}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
## Create a raw DataFrame
Now, let's create a task that returns a native Pandas DataFrame:
```
@env.task
async def create_raw_dataframe() -> pd.DataFrame:
return pd.DataFrame(BASIC_EMPLOYEE_DATA)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
This is the most basic use-case of how to pass DataFrames (of all kinds, not just Pandas).
We simply create the DataFrame as normal, and return it.
Because the task has been declared to return a supported native DataFrame type (in this case `pandas.DataFrame`), Flyte will automatically detect it, serialize it correctly, and upload it at task completion, enabling it to be passed transparently to the next task.
Flyte supports auto-serialization for the following DataFrame types:
* `pandas.DataFrame`
* `pyarrow.Table`
* `dask.dataframe.DataFrame`
* `polars.DataFrame`
* `flyte.io.DataFrame` (see below)
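For example, here is a minimal sketch that returns a `polars.DataFrame` natively (assuming `polars` is added to the environment image):
```
import polars as pl

@env.task
async def create_polars_dataframe() -> pl.DataFrame:
    # Returned natively; Flyte detects the type and serializes it automatically
    return pl.DataFrame({"id": [1, 2, 3], "value": ["x", "y", "z"]})
```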
## Create a flyte.io.DataFrame
Alternatively, you can create a `flyte.io.DataFrame` object directly from a native object with the `from_df` method:
```
@env.task
async def create_flyte_dataframe() -> Annotated[flyte.io.DataFrame, "parquet"]:
pd_df = pd.DataFrame(ADDL_EMPLOYEE_DATA)
fdf = flyte.io.DataFrame.from_df(pd_df)
return fdf
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
The `flyte.io.DataFrame` class creates a thin wrapper around objects of any standard DataFrame type. It serves as a generic "any DataFrame type" (a concept that Python itself does not currently offer).
As with native DataFrame types, Flyte will automatically serialize and upload the data at task completion.
The advantage of the unified `flyte.io.DataFrame` wrapper is that you can be explicit about the storage format that makes sense for your use case, by using an `Annotated` type where the second argument encodes the format or other lightweight hints. In the task above, for example, we specify that the DataFrame should be stored as Parquet.
## Automatically convert between types
You can leverage Flyte to automatically download and convert the DataFrame between types when needed:
```
@env.task
async def join_data(raw_dataframe: pd.DataFrame, flyte_dataframe: pd.DataFrame) -> flyte.io.DataFrame:
joined_df = raw_dataframe.merge(flyte_dataframe, on="employee_id", how="inner")
return flyte.io.DataFrame.from_df(joined_df)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
This task takes two DataFrames as input: we'll pass one raw Pandas DataFrame and one `flyte.io.DataFrame`.
Flyte automatically converts the `flyte.io.DataFrame` to a Pandas DataFrame (since we declared that as the input type) before passing it to the task.
The actual download and conversion happens only when we access the data, in this case, when we do the merge.
## Downloading DataFrames
When a task receives a `flyte.io.DataFrame`, you can request a concrete backend representation. For example, to download as a pandas DataFrame:
```
@env.task
async def download_data(joined_df: flyte.io.DataFrame):
downloaded = await joined_df.open(pd.DataFrame).all()
print("Downloaded Data:\n", downloaded)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
The `open()` call delegates to the DataFrame handler for the stored format and converts to the requested in-memory type.
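Other supported backends can be requested the same way. For instance, here is a sketch (using `pyarrow`, which is already in the environment image above) that opens the same data as a `pyarrow.Table`:
```
import pyarrow as pa

@env.task
async def download_as_arrow(joined_df: flyte.io.DataFrame):
    # Request a pyarrow.Table instead of a pandas DataFrame
    table = await joined_df.open(pa.Table).all()
    print("Rows:", table.num_rows)
```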
## Run the example
Finally, we can define a `main` function to run the tasks defined above and a `__main__` block to execute the workflow:
```
@env.task
async def main():
    raw_df = await create_raw_dataframe()
    flyte_df = await create_flyte_dataframe()
    joined_df = await join_data(raw_df, flyte_df)
    await download_data(joined_df)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/files-and-directories ===
# Files and directories
Flyte provides the **Build tasks > Files and directories > `flyte.io.File`** and
**Build tasks > Files and directories > `flyte.io.Dir`** types to represent files and directories, respectively.
Together with **Build tasks > DataFrames**, they constitute the *offloaded data types*: unlike materialized types such as data classes and Pydantic models (see **Build tasks > Data classes and structures**), these pass references rather than full data content.
A variable of an offloaded type does not contain its actual data, but rather a reference to the data.
The actual data is stored in the internal blob store of your Union/Flyte instance.
When a variable of an offloaded type is first created, its data is uploaded to the blob store.
It can then be passed from task to task as a reference.
The actual data is only downloaded from the blob store when the task needs to access it, for example, when the task calls `open()` on a `File` or `Dir` object.
This allows Flyte to efficiently handle large files and directories without needing to transfer the data unnecessarily.
Even very large data objects like video files and DNA datasets can be passed efficiently between tasks.
The `File` and `Dir` classes provide both `sync` and `async` methods to interact with the data.
## Example usage
The examples below show the basic use-cases of uploading files and directories created locally, and using them as inputs to a task.
```
import asyncio
import tempfile
from pathlib import Path
import flyte
from flyte.io import Dir, File
env = flyte.TaskEnvironment(name="files-and-folders")
@env.task
async def write_file(name: str) -> File:
# Create a file and write some content to it
with open("test.txt", "w") as f:
f.write(f"hello world {name}")
# Upload the file using flyte
uploaded_file_obj = await File.from_local("test.txt")
return uploaded_file_obj
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py)
The upload happens when the **Build tasks > Files and directories > `File.from_local`** command is called.
Because the upload would otherwise block execution, `File.from_local` is implemented as an `async` function.
The Flyte SDK frequently uses this class constructor pattern, so you will see it with other types as well.
Next is a slightly more complicated task that calls the task above to produce `File` objects.
These are assembled into a directory, and the resulting `Dir` object is returned, again by invoking `from_local`.
```
@env.task
async def write_and_check_files() -> Dir:
coros = []
for name in ["Alice", "Bob", "Eve"]:
coros.append(write_file(name=name))
vals = await asyncio.gather(*coros)
temp_dir = tempfile.mkdtemp()
for file in vals:
async with file.open("rb") as fh:
contents = await fh.read()
# Convert bytes to string
contents_str = contents.decode('utf-8') if isinstance(contents, bytes) else str(contents)
print(f"File {file.path} contents: {contents_str}")
new_file = Path(temp_dir) / file.name
with open(new_file, "w") as out: # noqa: ASYNC230
out.write(contents_str)
print(f"Files written to {temp_dir}")
# walk the directory and ls
for path in Path(temp_dir).iterdir():
print(f"File: {path.name}")
my_dir = await Dir.from_local(temp_dir)
return my_dir
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py)
Finally, these tasks show how to use an offloaded type as an input.
Helper functions like `walk` and `open` have been added to the objects
and do what you might expect.
```
@env.task
async def check_dir(my_dir: Dir):
print(f"Dir {my_dir.path} contents:")
async for file in my_dir.walk():
print(f"File: {file.name}")
async with file.open("rb") as fh:
contents = await fh.read()
# Convert bytes to string
contents_str = contents.decode('utf-8') if isinstance(contents, bytes) else str(contents)
print(f"Contents: {contents_str}")
@env.task
async def create_and_check_dir():
my_dir = await write_and_check_files()
await check_dir(my_dir=my_dir)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(create_and_check_dir)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py)
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/custom-context ===
# Custom context
Custom context provides a mechanism for implicitly passing configuration and metadata through your entire task execution hierarchy without adding parameters to every task. It is ideal for cross-cutting concerns such as tracing, environment metadata, or experiment identifiers.
Think of custom context as **execution-scoped metadata** that automatically flows from parent to child tasks.
## Overview
Custom context is an implicit key-value configuration map that is automatically available to tasks during execution. It is stored in the blob store of your Union/Flyte instance together with the task's inputs, making it available across tasks without needing to pass it explicitly.
You can access it in a Flyte task via:
```python
flyte.ctx().custom_context
```
Custom context is fundamentally different from standard task inputs. Task inputs are explicit, strongly typed parameters that you declare as part of a task's signature. They directly influence the task's computation and therefore participate in Flyte's caching and reproducibility guarantees.
Custom context, on the other hand, is implicit metadata. It consists only of string key/value pairs, is not part of the task signature, and does not affect task caching. Because it is injected by the Flyte runtime rather than passed as a formal input, it should be used only for environmental or contextual information, not for data that changes the logical output of a task.
## When to use it and when not to
Custom context is ideal when you need metadata, not domain data, to flow through your tasks.
Good use cases:
- Tracing IDs, span IDs
- Experiment or run metadata
- Environment region, cluster ID
- Logging correlation keys
- Feature flags
- Session IDs for 3rd-party APIs (e.g., an LLM session)
Avoid using for:
- Business/domain data
- Inputs that change task outputs
- Anything affecting caching or reproducibility
- Large blobs of data (keep it small)
It is the cleanest mechanism when you need something available everywhere, but not logically an input to the computation.
## Setting custom context
There are two ways to set custom context for a Flyte run:
1. Set it once for the entire run when you launch (`with_runcontext`); this establishes the base context for the execution.
2. Set or override it inside task code using the `flyte.custom_context(...)` context manager; this changes the active context for that task block and any nested tasks called from it.
Both are legitimate and complementary. The important behavioral rules to understand are:
- `with_runcontext(...)` sets the run-level base. Values provided here are available everywhere unless overridden later. Use this for metadata that should apply to most or all tasks in the run (experiment name, top-level trace id, run id, etc.).
- `flyte.custom_context(...)` is used inside task code to set or override values for that scope. It does affect nested tasks invoked while that context is active. In practice this means you can override run-level entries, add new keys for downstream tasks, or both.
- Merging & precedence: contexts are merged; when the same key appears in multiple places the most recent/innermost value wins (i.e., values set by `flyte.custom_context(...)` override the run-level values from `with_runcontext(...)` for the duration of that block).
### Run-level context
Set base metadata once when starting the run:
```
import flyte
env = flyte.TaskEnvironment("custom-context-example")
@env.task
async def leaf_task() -> str:
# Reads run-level context
print("leaf sees:", flyte.ctx().custom_context)
return flyte.ctx().custom_context.get("trace_id")
@env.task
async def root() -> str:
return await leaf_task()
if __name__ == "__main__":
flyte.init_from_config()
# Base context for the entire run
flyte.with_runcontext(custom_context={"trace_id": "root-abc", "experiment": "v1"}).run(root)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/custom-context/run_context.py)
Output (every task sees the base keys unless overridden):
```bash
leaf sees: {"trace_id": "root-abc", "experiment": "v1"}
```
### Overriding inside a task (local override that affects nested tasks)
Use `flyte.custom_context(...)` inside a task to override or add keys for downstream calls:
```
@env.task
async def downstream() -> str:
print("downstream sees:", flyte.ctx().custom_context)
return flyte.ctx().custom_context.get("trace_id")
@env.task
async def parent() -> str:
print("parent initial:", flyte.ctx().custom_context)
# Override the trace_id for the nested call(s)
with flyte.custom_context(trace_id="child-override"):
val = await downstream() # downstream sees trace_id="child-override"
# After the context block, run-level values are back
print("parent after:", flyte.ctx().custom_context)
return val
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/custom-context/override_context.py)
If the run was started with `{"trace_id": "root-abc"}`, this prints:
```bash
parent initial: {"trace_id": "root-abc"}
downstream sees: {"trace_id": "child-override"}
parent after: {"trace_id": "root-abc"}
```
Note that the override affected the nested downstream task because it was invoked while the `flyte.custom_context` block was active.
### Adding new keys for nested tasks
You can add keys (not just override):
```python
with flyte.custom_context(experiment="exp-blue", run_group="g-7"):
await some_task() # some_task sees both base keys + the new keys
```
## Accessing custom context
Always via the Flyte runtime:
```python
ctx = flyte.ctx().custom_context
value = ctx.get("key")
```
You can access the custom context using either `flyte.ctx().custom_context` or the shorthand `flyte.get_custom_context()`, which returns the same dictionary of key/value pairs.
Values are always strings, so parse as needed:
```python
timeout = int(ctx["timeout_seconds"])
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/reports ===
# Reports
The reports feature allows you to display and update custom output in the UI during task execution.
First, you set the `report=True` flag in the task decorator. This enables the reporting feature for that task.
Within a task with reporting enabled, a **Build tasks > Reports > `flyte.report.Report`** object is created automatically.
A `Report` object contains one or more tabs, each of which contains HTML.
You can write HTML to an existing tab and create new tabs to organize your content.
Initially, the `Report` object has one tab (the default tab) with no content.
To write content:
- **Flyte SDK > Packages > flyte.report > Methods > log()** appends HTML content directly to the default tab.
- **Flyte SDK > Packages > flyte.report > Methods > replace()** replaces the content of the default tab with new HTML.
To get or create a new tab:
- **Build tasks > Reports > `flyte.report.get_tab()`** allows you to specify a unique name for the tab, and it will return the existing tab if it already exists or create a new one if it doesn't.
It returns a `flyte.report._report.Tab` object.
You can `log()` or `replace()` HTML on the `Tab` object just as you can directly on the `Report` object.
Finally, you send the report to the Flyte server and make it visible in the UI:
- **Flyte SDK > Packages > flyte.report > Methods > flush()** dispatches the report.
**It is important to call this method to ensure that the data is sent**.
## A simple example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# ]
# main = "main"
# params = ""
# ///
import flyte
import flyte.report
env = flyte.TaskEnvironment(name="reports_example")
@env.task(report=True)
async def task1():
await flyte.report.replace.aio("
")
await flyte.report.flush.aio()
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(task1)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/simple.py)
Here we define a task `task1` that logs some HTML content to the default tab and creates a new tab named "Tab 2" where it logs additional HTML content.
The `flush` method is called to send the report to the backend.
## A more complex example
Here is another example.
We import the necessary modules, set up the task environment, define the main task with reporting enabled, and define the data generation function:
```
import json
import random
import flyte
import flyte.report
env = flyte.TaskEnvironment(
name="globe_visualization",
)
@env.task(report=True)
async def generate_globe_visualization():
await flyte.report.replace.aio(get_html_content())
await flyte.report.flush.aio()
def generate_globe_data():
"""Generate sample data points for the globe"""
cities = [
{"city": "New York", "country": "USA", "lat": 40.7128, "lng": -74.0060},
{"city": "London", "country": "UK", "lat": 51.5074, "lng": -0.1278},
{"city": "Tokyo", "country": "Japan", "lat": 35.6762, "lng": 139.6503},
{"city": "Sydney", "country": "Australia", "lat": -33.8688, "lng": 151.2093},
{"city": "Paris", "country": "France", "lat": 48.8566, "lng": 2.3522},
{"city": "SΓ£o Paulo", "country": "Brazil", "lat": -23.5505, "lng": -46.6333},
{"city": "Mumbai", "country": "India", "lat": 19.0760, "lng": 72.8777},
{"city": "Cairo", "country": "Egypt", "lat": 30.0444, "lng": 31.2357},
{"city": "Moscow", "country": "Russia", "lat": 55.7558, "lng": 37.6176},
{"city": "Beijing", "country": "China", "lat": 39.9042, "lng": 116.4074},
{"city": "Lagos", "country": "Nigeria", "lat": 6.5244, "lng": 3.3792},
{"city": "Mexico City", "country": "Mexico", "lat": 19.4326, "lng": -99.1332},
{"city": "Bangkok", "country": "Thailand", "lat": 13.7563, "lng": 100.5018},
{"city": "Istanbul", "country": "Turkey", "lat": 41.0082, "lng": 28.9784},
{"city": "Buenos Aires", "country": "Argentina", "lat": -34.6118, "lng": -58.3960},
{"city": "Cape Town", "country": "South Africa", "lat": -33.9249, "lng": 18.4241},
{"city": "Dubai", "country": "UAE", "lat": 25.2048, "lng": 55.2708},
{"city": "Singapore", "country": "Singapore", "lat": 1.3521, "lng": 103.8198},
{"city": "Stockholm", "country": "Sweden", "lat": 59.3293, "lng": 18.0686},
{"city": "Vancouver", "country": "Canada", "lat": 49.2827, "lng": -123.1207},
]
categories = ["high", "medium", "low", "special"]
data_points = []
for city in cities:
data_point = {**city, "value": random.randint(10, 100), "category": random.choice(categories)}
data_points.append(data_point)
return data_points
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py)
We then define the HTML content for the report:
```python
def get_html_content():
    data_points = generate_globe_data()
    html_content = f"""
    ...
    """
    return html_content
```
(We exclude it here due to length. You can find it in the [source file](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py)).
Finally, we run the workflow:
```
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(generate_globe_visualization)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py)
When the workflow runs, the report will be visible in the UI:

## Streaming example
Above we demonstrated reports that are sent to the UI once, at the end of the task execution.
But you can also stream updates to the report during task execution and see the display update in real time.
You do this by calling `flyte.report.flush()` (or specifying `do_flush=True` in `flyte.report.log()`) periodically during task execution, instead of only at the end.
> [!NOTE]
> In the above examples we explicitly call `flyte.report.flush()` to send the report to the UI.
> In fact, this is optional since flush will be called automatically at the end of the task execution.
> For streaming reports, on the other hand, calling `flush()` periodically (or specifying `do_flush=True`
> in `flyte.report.log()`) is necessary to display the updates.
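Before walking through the full dashboard example, here is a minimal sketch of the streaming pattern (the task name and HTML are illustrative):
```
import asyncio

import flyte
import flyte.report

env = flyte.TaskEnvironment(name="streaming_sketch")

@env.task(report=True)
async def progress_report(steps: int = 5):
    for i in range(steps):
        # Append a progress line and flush immediately so the UI updates live
        await flyte.report.log.aio(f"<p>Step {i + 1} of {steps} complete</p>", do_flush=True)
        await asyncio.sleep(1)
```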
First we import the necessary modules, and set up the task environment:
```
import asyncio
import json
import math
import random
import time
from datetime import datetime
from typing import List
import flyte
import flyte.report
env = flyte.TaskEnvironment(name="streaming_reports")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py)
Next we define the HTML content for the report:
```python
DATA_PROCESSING_DASHBOARD_HTML = """
...
"""
```
(We exclude it here due to length. You can find it in the [source file](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py)).
Finally, we define the task that renders the report (`data_processing_dashboard`), the driver task of the workflow (`main`), and the run logic:
```
@env.task(report=True)
async def data_processing_dashboard(total_records: int = 50000) -> str:
"""
Simulates a data processing pipeline with real-time progress visualization.
Updates every second for approximately 1 minute.
"""
await flyte.report.log.aio(DATA_PROCESSING_DASHBOARD_HTML, do_flush=True)
# Simulate data processing
processed = 0
errors = 0
batch_sizes = [800, 850, 900, 950, 1000, 1050, 1100] # Variable processing rates
start_time = time.time()
while processed < total_records:
# Simulate variable processing speed
batch_size = random.choice(batch_sizes)
# Add some processing delays occasionally
        if random.random() < 0.1:  # 10% chance of slower batch
            batch_size = int(batch_size * 0.6)
            # (slow-batch status HTML elided for length; see the source file)
            await flyte.report.log.aio("""
            """, do_flush=True)
        elif random.random() < 0.05:  # 5% chance of error
            errors += random.randint(1, 5)
            # (error status HTML elided for length; see the source file)
            await flyte.report.log.aio("""
            """, do_flush=True)
        else:
            # (normal status HTML elided for length; see the source file)
            await flyte.report.log.aio(f"""
            """, do_flush=True)
        processed = min(processed + batch_size, total_records)
        current_time = time.time()
        elapsed = current_time - start_time
        rate = int(batch_size) if elapsed < 1 else int(processed / elapsed)
        success_rate = ((processed - errors) / processed) * 100 if processed > 0 else 100
        # Update the dashboard
        # (dashboard-update HTML elided for length; see the source file)
        await flyte.report.log.aio(f"""
        """, do_flush=True)
print(f"Processed {processed:,} records, Errors: {errors}, Rate: {rate:,}"
f" records/sec, Success Rate: {success_rate:.2f}%", flush=True)
await asyncio.sleep(1) # Update every second
if processed >= total_records:
break
# Final completion message
total_time = time.time() - start_time
avg_rate = int(total_records / total_time)
await flyte.report.log.aio(f"""
    Processing Complete!
Total Records: {total_records:,}
Processing Time: {total_time:.1f} seconds
Average Rate: {avg_rate:,} records/second
Success Rate: {success_rate:.2f}%
Errors Handled: {errors}
""", do_flush=True)
print(f"Data processing completed: {processed:,} records processed with {errors} errors.", flush=True)
return f"Processed {total_records:,} records successfully"
@env.task
async def main():
"""
Main task to run both reports.
"""
await data_processing_dashboard(total_records=50000)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py)
The key to the live-update capability is the `while` loop that appends JavaScript to the report. Each appended script executes as soon as it is added to the document, updating the dashboard in place.
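To make that concrete, here is a hedged sketch of the kind of fragment such a loop might append (the helper function and element ID are hypothetical, not from the source file):
```python
import flyte.report

async def update_count(processed: int) -> None:
    # The appended <script> executes as soon as it is added to the document,
    # updating an element rendered by the initial dashboard HTML.
    # "processed-count" is a hypothetical element ID.
    await flyte.report.log.aio(
        f'<script>document.getElementById("processed-count").textContent = "{processed:,}";</script>',
        do_flush=True,
    )
```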
When the workflow runs, you can see the report updating in real-time in the UI:

=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/notebooks ===
# Notebooks
Flyte is designed to work seamlessly with Jupyter notebooks, allowing you to write and execute workflows directly within a notebook environment.
## Iterating on and running a workflow
Download the following notebook file and open it in your favorite Jupyter environment: [interactive.ipynb](../../_static/public/interactive.ipynb)
In this example we have a simple workflow defined in our notebook.
You can iterate on the code in the notebook while running each cell in turn.
Note that the **Flyte SDK > Packages > flyte > Methods > init()** call at the top of the notebook looks like this:
```python
flyte.init(
endpoint="https://union.example.com",
org="example_org",
project="example_project",
domain="development",
)
```
You will have to adjust it to match your Union server endpoint, organization, project, and domain.
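Alternatively, if you have a Flyte configuration file set up (as used throughout this guide), you can initialize from it instead:
```python
import flyte

# Reads the endpoint, org, project, and domain from your config file
flyte.init_from_config()
```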
## Accessing runs and downloading logs
Similarly, you can download the following notebook file and open it in your favorite Jupyter environment: [remote.ipynb](../../_static/public/remote.ipynb)
In this example we use the **Flyte SDK > Packages > flyte.remote** package to list existing runs, access them, and download their details and logs.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/remote-tasks ===
# Remote tasks
Remote tasks let you use previously deployed tasks without importing their code or dependencies. This enables teams to share and reuse tasks without managing complex dependency chains or container images.
## Prerequisites
Remote tasks must be deployed before you can use them. See **Run and deploy tasks** for details.
## Basic usage
Use `flyte.remote.Task.get()` to reference a deployed task:
```python
import flyte
import flyte.remote
env = flyte.TaskEnvironment(name="my_env")
# Get the latest version of a deployed task
data_processor = flyte.remote.Task.get(
"data_team.spark_analyzer",
auto_version="latest"
)
# Use it in your task
@env.task
async def my_task(data_path: str) -> flyte.io.DataFrame:
# Call the reference task like any other task
result = await data_processor(input_path=data_path)
return result
```
You can run this directly without deploying it:
```bash
flyte run my_workflow.py my_task --data_path s3://my-bucket/data.parquet
```
## Complete example
This example shows how different teams can collaborate using remote tasks.
### Team A: Spark environment
Team A maintains Spark-based data processing tasks:
```python
# spark_env.py
from dataclasses import dataclass
import flyte
env = flyte.TaskEnvironment(name="spark_env")
@dataclass
class AnalysisResult:
mean_value: float
std_dev: float
@env.task
async def analyze_data(data_path: str) -> AnalysisResult:
# Spark code here (not shown)
return AnalysisResult(mean_value=42.5, std_dev=3.2)
@env.task
async def compute_score(result: AnalysisResult) -> float:
# More Spark processing
return result.mean_value / result.std_dev
```
Deploy the Spark environment:
```bash
flyte deploy spark_env/
```
### Team B: ML environment
Team B maintains PyTorch-based ML tasks:
```python
# ml_env.py
from pydantic import BaseModel
import flyte
env = flyte.TaskEnvironment(name="ml_env")
class PredictionRequest(BaseModel):
feature_x: float
feature_y: float
class Prediction(BaseModel):
score: float
confidence: float
model_version: str
@env.task
async def run_inference(request: PredictionRequest) -> Prediction:
# PyTorch model inference (not shown)
return Prediction(
score=request.feature_x * 2.5,
confidence=0.95,
model_version="v2.1"
)
```
Deploy the ML environment:
```bash
flyte deploy ml_env/
```
### Team C: Orchestration
Team C builds a workflow using remote tasks from both teams without needing Spark or PyTorch dependencies:
```python
# orchestration_env.py
import flyte.remote
env = flyte.TaskEnvironment(name="orchestration")
# Reference tasks from other teams
analyze_data = flyte.remote.Task.get(
"spark_env.analyze_data",
auto_version="latest"
)
compute_score = flyte.remote.Task.get(
"spark_env.compute_score",
auto_version="latest"
)
run_inference = flyte.remote.Task.get(
"ml_env.run_inference",
auto_version="latest"
)
@env.task
async def orchestrate_pipeline(data_path: str) -> float:
# Use Spark tasks without Spark dependencies
analysis = await analyze_data(data_path=data_path)
# Access attributes from the result
# (Flyte creates a fake type that allows attribute access)
print(f"Analysis: mean={analysis.mean_value}, std={analysis.std_dev}")
data_score = await compute_score(result=analysis)
# Use ML task without PyTorch dependencies
# Pass Pydantic models as dictionaries
prediction = await run_inference(
request={
"feature_x": analysis.mean_value,
"feature_y": data_score
}
)
# Access Pydantic model attributes
print(f"Prediction: {prediction.score} (confidence: {prediction.confidence})")
return prediction.score
```
Run the orchestration task directly (no deployment needed):
**Using Python API**:
```python
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(
orchestrate_pipeline,
data_path="s3://my-bucket/data.parquet"
)
print(f"Execution URL: {run.url}")
# You can wait for the execution
run.wait()
# You can then retrieve the outputs
print(f"Pipeline result: {run.outputs()}")
```
**Using CLI**:
```bash
flyte run orchestration_env.py orchestrate_pipeline --data_path s3://my-bucket/data.parquet
```
## Invoke remote tasks in a script
You can also run any remote task directly from a script, in a similar way:
```python
import flyte
import flyte.models
import flyte.remote
flyte.init_from_config()
# Fetch the task
remote_task = flyte.remote.Task.get("package-example.calculate_average", auto_version="latest")
# Create a run. Note that keyword arguments are currently required; in the
# future positional arguments will be accepted based on declaration order,
# but we still recommend keyword arguments.
run = flyte.run(remote_task, numbers=[1.0, 2.0, 3.0])
print(f"Execution URL: {run.url}")
# You can view the current phase (run.phase is available in flyte >= 2.0.0b39)
print(f"Current phase: {run.phase}")
# You can wait for the execution to finish
run.wait()
print(f"Final phase: {run.phase}")
# Phases can be compared to the values of flyte.models.ActionPhase
if run.phase == flyte.models.ActionPhase.SUCCEEDED:
    print("Run completed!")
# You can then retrieve the outputs
print(f"Pipeline result: {run.outputs()}")
```
## Why use remote tasks?
Remote tasks solve common collaboration and dependency management challenges:
**Cross-team collaboration**: Team A has deployed a Spark task that analyzes large datasets. Team B needs this analysis for their ML pipeline but doesn't want to learn Spark internals, install Spark dependencies, or build Spark-enabled container images. With remote tasks, Team B simply references Team A's deployed task.
**Platform reusability**: Platform teams can create common, reusable tasks (data validation, feature engineering, model serving) that other teams can use without duplicating code or managing complex dependencies.
**Microservices for data workflows**: Remote tasks work like microservices for long-running tasks or agents, enabling secure sharing while maintaining isolation.
## When to use remote tasks
Use remote tasks when you need to:
- Use functionality from another team without their dependencies
- Share common tasks across your organization
- Build reusable platform components
- Avoid dependency conflicts between different parts of your workflow
- Create modular, maintainable data pipelines
## How remote tasks work
### Security model
Remote tasks run in the **caller's project and domain** using the caller's compute resources, but execute with the **callee's service accounts, IAM roles, and secrets**. This ensures:
- Tasks are secure from misuse
- Resource usage is properly attributed
- Authentication and authorization are maintained
- Collaboration remains safe and controlled
### Type system
Remote tasks use Flyte's default types as inputs and outputs. Flyte's type system seamlessly translates data between tasks without requiring the original dependencies:
| Remote Task Type | Flyte Type |
|-------------------|------------|
| DataFrames (`pandas`, `polars`, `spark`, etc.) | `flyte.io.DataFrame` |
| Object store files | `flyte.io.File` |
| Object store directories | `flyte.io.Dir` |
| Pydantic models | Dictionary (Flyte creates a representation) |
Any DataFrame type (pandas, polars, spark) automatically becomes `flyte.io.DataFrame`, allowing seamless data exchange between tasks using different DataFrame libraries. You can also write custom integrations or explore Flyte's plugin system for additional types.
For Pydantic models specifically, you don't need the exact model locally. Pass a dictionary as input, and Flyte will handle the translation.
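As a minimal sketch (reusing the `ml_env.run_inference` task from the complete example above; the calling environment name here is illustrative):
```python
import flyte
import flyte.remote

env = flyte.TaskEnvironment(name="dict_input_example")

# Reference the deployed task; its input is declared as a Pydantic model
run_inference = flyte.remote.Task.get("ml_env.run_inference", auto_version="latest")

@env.task
async def score(x: float, y: float) -> float:
    # Pass a plain dict; Flyte translates it into the callee's model
    prediction = await run_inference(request={"feature_x": x, "feature_y": y})
    return prediction.score
```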
## Versioning options
Reference tasks support flexible versioning:
**Specific version**:
```python
task = flyte.remote.Task.get(
"team_a.process_data",
version="v1.2.3"
)
```
**Latest version** (`auto_version="latest"`):
```python
# Always use the most recently deployed version
task = flyte.remote.Task.get(
"team_a.process_data",
auto_version="latest"
)
```
**Current version** (`auto_version="current"`):
```python
# Use the same version as the calling task's deployment
# Useful when all environments deploy with the same version
# Can only be used from within a task context
task = flyte.remote.Task.get(
"team_a.process_data",
auto_version="current"
)
```
## Best practices
### 1. Use meaningful task names
Remote tasks are accessed by name, so use clear, descriptive naming:
```python
# Good
customer_segmentation = flyte.remote.Task.get("ml_platform.customer_segmentation")
# Avoid
task1 = flyte.remote.Task.get("team_a.task1")
```
### 2. Document task interfaces
Since remote tasks abstract away implementation details, clear documentation of inputs, outputs, and behavior is essential:
```python
@env.task
async def process_customer_data(
customer_ids: list[str],
date_range: tuple[str, str]
) -> flyte.io.DataFrame:
"""
Process customer data for the specified date range.
Args:
customer_ids: List of customer IDs to process
date_range: Tuple of (start_date, end_date) in YYYY-MM-DD format
Returns:
DataFrame with processed customer features
"""
...
```
### 3. Handle versioning thoughtfully
- Use `auto_version="latest"` during development for rapid iteration
- Use specific versions in production for stability and reproducibility
- Use `auto_version="current"` when coordinating multi-environment deployments
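A minimal sketch of switching between pinned and latest versions (the environment-variable gate is illustrative, not a Flyte convention):
```python
import os

import flyte.remote

# Pin an exact version in production; track the latest while iterating
if os.environ.get("DEPLOY_ENV") == "prod":
    process_data = flyte.remote.Task.get("team_a.process_data", version="v1.2.3")
else:
    process_data = flyte.remote.Task.get("team_a.process_data", auto_version="latest")
```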
### 4. Deploy remote tasks first
Always deploy the remote tasks before using them. Tasks that reference them can be run directly without deployment:
```bash
# Deploy the Spark environment first
flyte deploy spark_env/
# Deploy the ML environment
flyte deploy ml_env/
# Now you can run the orchestration task directly (no deployment needed)
flyte run orchestration_env.py orchestrate_pipeline
```
If you want to deploy the orchestration task as well (for scheduled runs or to be referenced by other tasks), deploy it after its dependencies:
```bash
flyte deploy orchestration_env/
```
## Limitations
1. **Type fidelity**: While Flyte translates types seamlessly, you work with Flyte's representation of Pydantic models, not the exact original types
2. **Deployment order**: Referenced tasks must be deployed before tasks that reference them
3. **Context requirement**: Using `auto_version="current"` requires running within a task context
4. **Dictionary inputs**: Pydantic models must be passed as dictionaries, which loses compile-time type checking
## Next steps
- Learn about **Run and deploy tasks**
- Explore **Configure tasks**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/error-handling ===
# Error handling
One of the key features of Flyte 2 is the ability to recover from user-level errors in a workflow execution.
This includes out-of-memory errors and other exceptions.
In a distributed system with heterogeneous compute, certain types of errors are expected and even, in a sense, acceptable.
Flyte 2 recognizes this and allows you to handle them gracefully as part of your workflow logic.
This ability is a direct result of the fact that workflows are now written in regular Python,
giving you all the power and flexibility of Python error handling.
Let's look at an example:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# ]
# main = "main"
# params = ""
# ///
import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment(name="fail", resources=flyte.Resources(cpu=1, memory="250Mi"))
@env.task
async def oomer(x: int):
large_list = [0] * 100000000
print(len(large_list))
@env.task
async def always_succeeds() -> int:
await asyncio.sleep(1)
return 42
@env.task
async def main() -> int:
try:
await oomer(2)
except flyte.errors.OOMError as e:
print(f"Failed with oom trying with more resources: {e}, of type {type(e)}, {e.code}")
try:
await oomer.override(resources=flyte.Resources(cpu=1, memory="1Gi"))(5)
except flyte.errors.OOMError as e:
print(f"Failed with OOM Again giving up: {e}, of type {type(e)}, {e.code}")
raise e
finally:
await always_succeeds()
return await always_succeeds()
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(main)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/error-handling/error_handling.py)
In this code, we do the following:
* Import the necessary modules
* Set up the task environment. Note that we define our task environment with a resource allocation of 1 CPU and 250 MiB of memory.
* Define two tasks: one that will intentionally cause an out-of-memory (OOM) error, and another that will always succeed.
* Define the main task (the top level workflow task) that will handle the failure recovery logic.
The outer `try...except` block attempts to run the `oomer` task with a parameter that is likely to cause an OOM error.
If the error occurs, it catches the **Build tasks > Error handling > `flyte.errors.OOMError`** and attempts to run the `oomer` task again with increased resources.
This type of dynamic error handling allows you to gracefully recover from user-level errors in your workflows.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/traces ===
# Traces
The `@flyte.trace` decorator provides fine-grained observability and resumption capabilities for functions called within your Flyte workflows.
Traces are used on **helper functions** that tasks call to perform specific operations like API calls, data processing, or computations.
Traces are particularly useful for non-deterministic operations (see **Considerations > Non-deterministic behavior**), allowing you to track execution details and resume from failures.
## What are traced functions for?
At the top level, Flyte workflows are composed of **tasks**. But it is also common practice to break down complex task logic into smaller, reusable functions by defining helper functions that tasks call to perform specific operations.
Any helper functions defined or imported into the same file as a task definition are automatically uploaded to the Flyte environment alongside the task when it is deployed.
At the task level, observability and resumption of failed executions is provided by caching, but what if you want these capabilities at a more granular level, for the individual operations that tasks perform?
This is where **traced functions** come in. By decorating helper functions with `@flyte.trace`, you enable:
- **Detailed observability**: Track execution time, inputs/outputs, and errors for each function call.
- **Fine-grained resumption**: If a workflow fails, resume from the last successful traced function instead of re-running the entire task.
Each traced function is effectively a checkpoint within its task.
Here is an example:
```
import asyncio
import flyte
env = flyte.TaskEnvironment("env")
@flyte.trace
async def call_llm(prompt: str) -> str:
await asyncio.sleep(0.1)
return f"LLM response for: {prompt}"
@flyte.trace
async def process_data(data: str) -> dict:
await asyncio.sleep(0.2)
return {"processed": data, "status": "completed"}
@env.task
async def research_workflow(topic: str) -> dict:
llm_result = await call_llm(f"Generate research plan for: {topic}")
processed_data = await process_data(llm_result)
return {"topic": topic, "result": processed_data}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/task_vs_trace.py)
## What Gets Traced
Traces capture detailed execution information:
- **Execution time**: How long each function call takes.
- **Inputs and outputs**: Function parameters and return values.
- **Checkpoints**: State that enables workflow resumption.
### Errors are not recorded
Only successful trace executions are recorded in the checkpoint system. When a traced function fails, the exception propagates up to your task code where you can handle it with standard error handling patterns.
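For instance, a minimal sketch of that pattern (the `flaky_api_call` helper and its failure are illustrative):
```python
import flyte

env = flyte.TaskEnvironment(name="trace_error_example")

@flyte.trace
async def flaky_api_call(query: str) -> str:
    # A failure here is not checkpointed; the exception simply propagates
    raise RuntimeError(f"API unavailable for: {query}")

@env.task
async def robust_task(query: str) -> str:
    try:
        return await flaky_api_call(query)
    except RuntimeError:
        # Handle the failure with ordinary Python error handling
        return "fallback result"
```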
### Supported Function Types
The trace decorator works with:
- **Asynchronous functions**: Functions defined with `async def`.
- **Generator functions**: Functions that `yield` values.
- **Async generators**: `async def` functions that `yield` values.
> [!NOTE]
> Currently tracing only works for asynchronous functions. Tracing of synchronous functions is coming soon.
```
@flyte.trace
async def async_api_call(topic: str) -> dict:
# Asynchronous API call
await asyncio.sleep(0.1)
return {"data": ["item1", "item2", "item3"], "status": "success"}
@flyte.trace
async def stream_data(items: list[str]):
# Async generator function for streaming
for item in items:
await asyncio.sleep(0.02)
yield f"Processing: {item}"
@flyte.trace
async def async_stream_llm(prompt: str):
# Async generator for streaming LLM responses
chunks = ["Research shows", " that machine learning", " continues to evolve."]
for chunk in chunks:
await asyncio.sleep(0.05)
yield chunk
@env.task
async def research_workflow(topic: str) -> dict:
llm_result = await async_api_call(topic)
# Collect async generator results
processed_data = []
async for item in stream_data(llm_result["data"]):
processed_data.append(item)
llm_stream = []
async for chunk in async_stream_llm(f"Summarize research on {topic}"):
llm_stream.append(chunk)
return {
"topic": topic,
"processed_data": processed_data,
"llm_summary": "".join(llm_stream)
}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/function_types.py)
## Task Orchestration Pattern
The typical Flyte workflow follows this pattern:
```
@flyte.trace
async def search_web(query: str) -> list[dict]:
# Search the web and return results
await asyncio.sleep(0.1)
return [{"title": f"Article about {query}", "content": f"Content on {query}"}]
@flyte.trace
async def summarize_content(content: str) -> str:
# Summarize content using LLM
await asyncio.sleep(0.1)
return f"Summary of {len(content.split())} words"
@flyte.trace
async def extract_insights(summaries: list[str]) -> dict:
# Extract insights from summaries
await asyncio.sleep(0.1)
return {"insights": ["key theme 1", "key theme 2"], "count": len(summaries)}
@env.task
async def research_pipeline(topic: str) -> dict:
# Each helper function creates a checkpoint
search_results = await search_web(f"research on {topic}")
summaries = []
for result in search_results:
summary = await summarize_content(result["content"])
summaries.append(summary)
final_insights = await extract_insights(summaries)
return {
"topic": topic,
"insights": final_insights,
"sources_count": len(search_results)
}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/pattern.py)
**Benefits of this pattern:**
- If `search_web` succeeds but `summarize_content` fails, resumption skips the search step
- Each operation is independently observable and debuggable
- Clear separation between workflow coordination (task) and execution (traced functions)
## Relationship to Caching and Checkpointing
Understanding how traces work with Flyte's other execution features:
| Feature | Scope | Purpose | Default Behavior |
|---------|-------|---------|------------------|
| **Task Caching** | Entire task execution (`@env.task`) | Skip re-running tasks with same inputs | Enabled (`cache="auto"`) |
| **Traces** | Individual helper functions | Observability and fine-grained resumption | Manual (requires `@flyte.trace`) |
| **Checkpointing** | Workflow state | Resume workflows from failure points | Automatic when traces are used |
### How They Work Together
```
@flyte.trace
async def traced_data_cleaning(dataset_id: str) -> List[str]:
# Creates checkpoint after successful execution.
await asyncio.sleep(0.2)
return [f"cleaned_record_{i}_{dataset_id}" for i in range(100)]
@flyte.trace
async def traced_feature_extraction(data: List[str]) -> dict:
# Creates checkpoint after successful execution.
await asyncio.sleep(0.3)
return {
"features": [f"feature_{i}" for i in range(10)],
"feature_count": len(data),
"processed_samples": len(data)
}
@flyte.trace
async def traced_model_training(features: dict) -> dict:
# Creates checkpoint after successful execution.
await asyncio.sleep(0.4)
sample_count = features["processed_samples"]
# Mock accuracy based on sample count
accuracy = min(0.95, 0.7 + (sample_count / 1000))
return {
"accuracy": accuracy,
"epochs": 50,
"model_size": "125MB"
}
@env.task(cache="auto") # Task-level caching enabled
async def data_pipeline(dataset_id: str) -> dict:
# 1. If this exact task with these inputs ran before,
# the entire task result is returned from cache
# 2. If not cached, execution begins and each traced function
# creates checkpoints for resumption
cleaned_data = await traced_data_cleaning(dataset_id) # Checkpoint 1
features = await traced_feature_extraction(cleaned_data) # Checkpoint 2
model_results = await traced_model_training(features) # Checkpoint 3
# 3. If workflow fails at step 3, resumption will:
# - Skip traced_data_cleaning (checkpointed)
# - Skip traced_feature_extraction (checkpointed)
# - Re-run only traced_model_training
return {"dataset_id": dataset_id, "accuracy": model_results["accuracy"]}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/caching_vs_checkpointing.py)
### Execution Flow
1. **Task Submission**: Task is submitted with input parameters
2. **Cache Check**: Flyte checks if identical task execution exists in cache
3. **Cache Hit**: If cached, return cached result immediately (no traces needed)
4. **Cache Miss**: Begin fresh execution
5. **Trace Checkpoints**: Each `@flyte.trace` function creates resumption points
6. **Failure Recovery**: If workflow fails, resume from last successful checkpoint
7. **Task Completion**: Final result is cached for future identical inputs
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/grouping-actions ===
# Grouping actions
Groups are an organizational feature in Flyte that allow you to logically cluster related task invocations (called "actions") for better visualization and management in the UI.
Groups help you organize task executions into manageable, hierarchical structures regardless of whether you're working with large fanouts or smaller, logically-related sets of operations.
## What are groups?
Groups provide a way to organize task invocations into logical units in the Flyte UI.
When you have multiple task executions, whether from large fanouts (see **Build tasks > Fanout**), sequential operations, or any combination of tasks, groups help organize them into manageable units.
### The problem groups solve
Without groups, complex workflows can become visually overwhelming in the Flyte UI:
- Multiple task executions appear as separate nodes, making it hard to see the high-level structure
- Related operations are scattered throughout the workflow graph
- Debugging and monitoring becomes difficult when dealing with many individual task executions
Groups solve this by:
- **Organizing actions**: Multiple task executions within a group are presented as a hierarchical "folder" structure
- **Improving UI visualization**: Instead of many individual nodes cluttering the view, you see logical groups that can be collapsed or expanded
- **Aggregating status information**: Groups show aggregated run status (success/failure) of their contained actions when you hover over them in the UI
- **Maintaining execution parallelism**: Tasks still run concurrently as normal, but are organized for display
### How groups work
Groups are declared using the **Flyte SDK > Packages > flyte > Methods > group()** context manager.
Any task invocations that occur within the `with flyte.group()` block are automatically associated with that group:
```python
with flyte.group("my-group-name"):
# All task invocations here belong to "my-group-name"
result1 = await task_a(data)
result2 = await task_b(data)
result3 = await task_c(data)
```
The key points about groups:
1. **Context-based**: Use the `with flyte.group("name"):` context manager.
2. **Organizational tool**: Task invocations within the context are grouped together in the UI.
3. **UI folders**: Groups appear as collapsible/expandable folders in the Flyte UI run tree.
4. **Status aggregation**: Hover over a group in the UI to see aggregated success/failure information.
5. **Execution unchanged**: Tasks still execute in parallel as normal; groups only affect organization and visualization.
**Important**: Groups do not aggregate outputs. Each task execution still produces its own individual outputs. Groups are purely for organization and UI presentation.
## Common grouping patterns
### Sequential operations
Group related sequential operations that logically belong together:
```
@env.task
async def data_pipeline(raw_data: str) -> str:
with flyte.group("data-validation"):
validated_data = await process_data(raw_data, "validate_schema")
validated_data = await process_data(validated_data, "check_quality")
validated_data = await process_data(validated_data, "remove_duplicates")
with flyte.group("feature-engineering"):
features = await process_data(validated_data, "extract_features")
features = await process_data(features, "normalize_features")
features = await process_data(features, "select_features")
with flyte.group("model-training"):
model = await process_data(features, "train_model")
model = await process_data(model, "validate_model")
final_model = await process_data(model, "save_model")
return final_model
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Parallel processing with groups
Groups work well with parallel execution patterns:
```
@env.task
async def parallel_processing_example(n: int) -> str:
tasks = []
with flyte.group("parallel-processing"):
# Collect all task invocations first
for i in range(n):
tasks.append(process_item(i, "transform"))
# Execute all tasks in parallel
results = await asyncio.gather(*tasks)
# Convert to string for consistent return type
return f"parallel_results: {results}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Multi-phase workflows
Use groups to organize different phases of complex workflows:
```
@env.task
async def multi_phase_workflow(data_size: int) -> str:
# First phase: data preprocessing
preprocessed = []
with flyte.group("preprocessing"):
for i in range(data_size):
preprocessed.append(process_item(i, "preprocess"))
phase1_results = await asyncio.gather(*preprocessed)
# Second phase: main processing
processed = []
with flyte.group("main-processing"):
for result in phase1_results:
processed.append(process_item(result, "transform"))
phase2_results = await asyncio.gather(*processed)
# Third phase: postprocessing
postprocessed = []
with flyte.group("postprocessing"):
for result in phase2_results:
postprocessed.append(process_item(result, "postprocess"))
final_results = await asyncio.gather(*postprocessed)
# Convert to string for consistent return type
return f"multi_phase_results: {final_results}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Nested groups
Groups can be nested to create hierarchical organization:
```
@env.task
async def hierarchical_example(raw_data: str) -> str:
with flyte.group("machine-learning-pipeline"):
with flyte.group("data-preparation"):
cleaned_data = await process_data(raw_data, "clean_data")
split_data = await process_data(cleaned_data, "split_dataset")
with flyte.group("model-experiments"):
with flyte.group("hyperparameter-tuning"):
best_params = await process_data(split_data, "tune_hyperparameters")
with flyte.group("model-training"):
model = await process_data(best_params, "train_final_model")
return model
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Conditional grouping
Groups can be used with conditional logic:
```
@env.task
async def conditional_processing(use_advanced_features: bool, input_data: str) -> str:
base_result = await process_data(input_data, "basic_processing")
if use_advanced_features:
with flyte.group("advanced-features"):
enhanced_result = await process_data(base_result, "advanced_processing")
optimized_result = await process_data(enhanced_result, "optimize_result")
return optimized_result
else:
with flyte.group("basic-features"):
simple_result = await process_data(base_result, "simple_processing")
return simple_result
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
## Key insights
Groups are primarily an organizational and UI visualization tool: they don't change how your tasks execute or aggregate their outputs, but they help organize related task invocations (actions) into collapsible folder-like structures for better workflow management and display. The aggregated status information (success/failure rates) is visible when hovering over group folders in the UI.
Groups make your Flyte workflows more maintainable and easier to understand, especially when working with complex workflows that involve multiple logical phases or large numbers of task executions. They serve as organizational "folders" in the UI's call stack tree, allowing you to collapse sections to reduce visual distraction while still seeing aggregated status information on hover.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-programming/fanout ===
# Fanout
Flyte is designed to scale effortlessly, allowing you to run workflows with large fan-outs.
When you need to execute many tasks in parallel, such as processing a large dataset or running hyperparameter sweeps, Flyte provides powerful patterns to implement these operations efficiently.
## Understanding fanout
A "fanout" pattern occurs when you spawn multiple tasks concurrently.
Each task runs in its own container and contributes an output that you later collect.
The most common way to implement this is using the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.
In Flyte terminology, each individual task execution is called an "action": a specific invocation of a task with particular inputs. When you call a task multiple times in a loop, you create multiple actions.
## Example
We start by importing our required packages, defining our Flyte environment, and creating a simple task that fetches user data from a mock API.
```
import asyncio
from typing import List, Tuple
import flyte
env = flyte.TaskEnvironment("fanout_env")
@env.task
async def fetch_data(user_id: int) -> dict:
"""Simulate fetching user data from an API - good for parallel execution."""
# Simulate network I/O delay
await asyncio.sleep(0.1)
return {
"user_id": user_id,
"name": f"User_{user_id}",
"score": user_id * 10,
"data": f"fetched_data_{user_id}"
}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py)
### Parallel execution
Next we implement the most common fanout pattern, which is to collect task invocations and execute them in parallel using `asyncio.gather()`:
```
@env.task
async def parallel_data_fetching(user_ids: List[int]) -> List[dict]:
"""Fetch data for multiple users in parallel - ideal for I/O bound operations."""
tasks = []
# Collect all fetch tasks - these can run in parallel since they're independent
for user_id in user_ids:
tasks.append(fetch_data(user_id))
# Execute all fetch operations in parallel
results = await asyncio.gather(*tasks)
return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py)
### Running the example
To actually run the example, we create a `__main__` guard that initializes Flyte and runs our driver task:
```
if __name__ == "__main__":
flyte.init_from_config()
user_ids = [1, 2, 3, 4, 5]
r = flyte.run(parallel_data_fetching, user_ids)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py)
## How Flyte handles concurrency and parallelism
In the example we use a standard `asyncio.gather()` pattern.
When this pattern is used in a normal Python environment, the tasks execute **concurrently** (cooperatively sharing a single thread through the event loop), but not in **parallel** (on multiple CPU cores simultaneously).
However, **Flyte transforms this concurrency model into true parallelism**. When you use `asyncio.gather()` in a Flyte task:
1. **Flyte acts as a distributed event loop**: Instead of scheduling coroutines on a single machine, Flyte schedules each task action to run in its own container across the cluster
2. **Concurrent becomes parallel**: What would be cooperative multitasking in regular Python becomes true parallel execution across multiple machines
3. **Native Python patterns**: You use familiar `asyncio` patterns, but Flyte automatically distributes the work
This means that when you write:
```python
results = await asyncio.gather(fetch_data(1), fetch_data(2), fetch_data(3))
```
Instead of three coroutines sharing one CPU, you get three separate containers running simultaneously, each with its own CPU, memory, and resources. Flyte seamlessly bridges the gap between Python's concurrency model and distributed parallel computing, allowing for massive scalability while maintaining the familiar async/await programming model.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment ===
# Run and deploy tasks
You have seen how to configure and build the tasks that compose your project.
Now you need to decide how to execute them on your Flyte backend.
Flyte offers two distinct approaches for getting your tasks onto the backend:
**Use `flyte run` when you're iterating and experimenting:**
- Quickly test changes during development
- Try different parameters or code modifications
- Debug issues without creating permanent artifacts
- Prototype new ideas rapidly
**Use `flyte deploy` when your project is ready to be formalized:**
- Freeze a stable version of your tasks for repeated use
- Share tasks with team members or across environments
- Move from experimentation to a more structured workflow
- Create a permanent reference point (not necessarily production-ready)
This section explains both approaches and when to use each one.
## Ephemeral deployment and immediate execution
The `flyte run` CLI command and the `flyte.run()` SDK function are used to **ephemerally deploy** and **immediately execute** a task on the backend in a single step.
The task can be re-run and its execution and outputs can be observed in the **Runs list** UI, but it is not permanently added to the **Tasks list** on the backend.
Let's say you have the following file called `greeting.py`:
```python
# greeting.py
import flyte
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
```
### With the `flyte run` CLI command
The general form of the command for running a task from a local file is:
```bash
flyte run <file.py> <task_name> [--<param> <value> ...]
```
So, to run the `greet` task defined in the `greeting.py` file, you would run:
```bash
flyte run greeting.py greet --message "Good morning!"
```
This command:
1. **Temporarily deploys** the task environment named `greeting_env` (held by the variable `env`) that contains the `greet` task.
2. **Executes** the `greet` function with argument `message` set to `"Good morning!"`. Note that `message` is the actual parameter name defined in the function signature.
3. **Returns** the execution results and displays them in the terminal.
### With the `flyte.run()` SDK function
You can also do the same thing programmatically using the `flyte.run()` function:
```python
# greeting.py
import flyte
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
if __name__ == "__main__":
flyte.init_from_config()
result = flyte.run(greet, message="Good morning!")
print(f"Result: {result}")
```
Here we add a `__main__` block to the `greeting.py` file that initializes the Flyte SDK from the configuration file and then calls `flyte.run()` with the `greet` task and its argument.
Now you can run the `greet` task on the backend just by executing the `greeting.py` file locally as a script:
```bash
python greeting.py
```
For more details on how `flyte run` and `flyte.run()` work under the hood, see **Run and deploy tasks > How task run works**.
## Persistent deployment
The `flyte deploy` CLI command and the `flyte.deploy()` SDK function are used to **persistently deploy** a task environment (and all its contained tasks) to the backend.
The tasks within the deployed environment will appear in the **Tasks list** UI on the backend and can then be executed multiple times without needing to redeploy them.
### With the `flyte deploy` CLI command
The general form of the command for deploying a task environment from a local file is:
```bash
flyte deploy <file.py> <env_variable>
```
So, using the same `greeting.py` file as before, you can deploy the `greeting_env` task environment like this:
```bash
flyte deploy greeting.py env
```
This command deploys the task environment *assigned to the variable `env`* in the `greeting.py` file, which is the `TaskEnvironment` named `greeting_env`.
Notice that you must specify the *variable* to which the `TaskEnvironment` is assigned (`env` in this case), not the name of the environment itself (`greeting_env`).
Deploying a task environment deploys all tasks defined within it. Here, that means all functions decorated with `@env.task`.
In this case there is just one: `greet()`.
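For instance, a minimal sketch with two tasks (names are illustrative): deploying the one environment (`env`) deploys both tasks at once.
```python
# multi_task.py
import flyte

env = flyte.TaskEnvironment(name="multi_task_env")

@env.task
async def step_one(x: int) -> int:
    return x + 1

@env.task
async def step_two(x: int) -> int:
    return x * 2

# Running `flyte deploy multi_task.py env` deploys both step_one and step_two
```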
### With the `flyte.deploy()` SDK function
You can also do the same thing programmatically using the `flyte.deploy()` function:
```python
# greeting.py
import flyte
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
if __name__ == "__main__":
flyte.init_from_config()
deployments = flyte.deploy(env)
print(deployments[0].summary_repr())
```
Now you can deploy the `greeting_env` task environment (and therefore the `greet()` task) just by executing the `greeting.py` file locally as a script.
```bash
python greeting.py
```
For more details on how `flyte deploy` and `flyte.deploy()` work under the hood, see **Run and deploy tasks > How task deployment works**.
## Running already deployed tasks
If you have already deployed your task environment, you can run its tasks without redeploying by using the `flyte run` CLI command or the `flyte.run()` SDK function in a slightly different way. Alternatively, you can always initiate execution of a deployed task from the UI.
### With the `flyte run` CLI command
To run a permanently deployed task using the `flyte run` CLI command, use the special `deployed-task` keyword followed by the task reference in the format `{environment_name}.{task_name}`. For example, to run the previously deployed `greet` task from the `greeting_env` environment, you would run:
```bash
flyte run deployed-task greeting_env.greet --message "World"
```
Notice that now that the task environment is deployed, you use its name (`greeting_env`), not the variable name to which it was assigned in source code (`env`).
The task environment name plus the task name (`greet`) are combined with a dot (`.`) to form the full task reference: `greeting_env.greet`.
The special `deployed-task` keyword tells the CLI that you are referring to a task that has already been deployed. In effect, it replaces the file path argument used for ephemeral runs.
When executed, this command will run the already-deployed `greet` task with argument `message` set to `"World"`. You will see the result printed in the terminal. You can also, of course, observe the execution in the **Runs list** UI.
### With the `flyte.run()` SDK function
You can also run already-deployed tasks programmatically using the `flyte.run()` function.
For example, to run the previously deployed `greet` task from the `greeting_env` environment, you would do:
```python
# greeting.py
import flyte
import flyte.remote
env = flyte.TaskEnvironment(name="greeting_env")
@env.task
async def greet(message: str) -> str:
return f"{message}!"
if __name__ == "__main__":
flyte.init_from_config()
flyte.deploy(env)
task = flyte.remote.Task.get("greeting_env.greet", auto_version="latest")
result = flyte.run(task, message="Good morning!")
print(f"Result: {result}")
```
When you execute this script locally, it will:
- Deploy the `greeting_env` task environment as before.
- Retrieve the already-deployed `greet` task using `flyte.remote.Task.get()`, specifying its full task reference as a string: `"greeting_env.greet"`.
- Call `flyte.run()` with the retrieved task and its argument.
For more details on how running already-deployed tasks works, see **Run and deploy tasks > How task run works > Running deployed tasks**.
## Subpages
- **Run and deploy tasks > How task run works**
- **Run and deploy tasks > Run command options**
- **Run and deploy tasks > How task deployment works**
- **Run and deploy tasks > Deploy command options**
- **Run and deploy tasks > Code packaging for remote execution**
- **Run and deploy tasks > Deployment patterns**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/how-task-run-works ===
# How task run works
The `flyte run` command and `flyte.run()` SDK function support three primary execution modes:
1. **Ephemeral deployment + run**: Automatically prepare task environments ephemerally and execute tasks (development shortcut)
2. **Run deployed task**: Execute permanently deployed tasks without redeployment
3. **Local execution**: Run tasks on your local machine for development and testing
Additionally, you can run deployed tasks through the Flyte/Union UI for interactive execution and monitoring.
## Ephemeral deployment + run: The development shortcut
The most common development pattern combines ephemeral task preparation and execution in a single command, automatically handling the temporary deployment process when needed.
### CLI: Ephemeral deployment and execution
```bash
# Basic deploy + run
flyte run my_example.py my_task --name "World"
# With explicit project and domain
flyte run --project my-project --domain development my_example.py my_task --name "World"
# With deployment options
flyte run --version v1.0.0 --copy-style all my_example.py my_task --name "World"
```
**How it works:**
1. **Environment discovery**: Flyte loads the specified Python file and identifies task environments
2. **Ephemeral preparation**: Temporarily prepares the task environment for execution (similar to deployment but not persistent)
3. **Task execution**: Immediately runs the specified task with provided arguments in the ephemeral environment
4. **Result return**: Returns execution results and monitoring URL
5. **Cleanup**: The ephemeral environment is not stored permanently in the backend
### SDK: Programmatic ephemeral deployment + run
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def my_task(name: str) -> str:
return f"Hello, {name}!"
if __name__ == "__main__":
flyte.init_from_config()
# Deploy and run in one step
result = flyte.run(my_task, name="World")
print(f"Result: {result}")
print(f"Execution URL: {result.url}")
```
**Benefits of ephemeral deployment + run:**
- **Development efficiency**: No separate permanent deployment step required
- **Always current**: Uses your latest code changes without polluting the backend
- **Clean development**: Ephemeral environments don't clutter your task registry
- **Integrated workflow**: Single command for complete development cycle
## Running deployed tasks
For production workflows or when you want to use stable deployed versions, you can run tasks that have been **permanently deployed** with `flyte deploy` without triggering any deployment process.
### CLI: Running deployed tasks
```bash
# Run a previously deployed task
flyte run deployed-task my_env.my_task --name "World"
# With specific project/domain
flyte run --project prod --domain production deployed-task my_env.my_task --batch_size 1000
```
**Task reference format:** `{environment_name}.{task_name}`
- `environment_name`: The `name` property of your `TaskEnvironment`
- `task_name`: The function name of your task
>[!NOTE]
> Recall that when you deploy a task environment with `flyte deploy`, you specify the `TaskEnvironment` using the variable to which it is assigned.
> In contrast, once it is deployed, you refer to the environment by its `name` property.
### SDK: Running deployed tasks
```python
import flyte
import flyte.remote
flyte.init_from_config()
# Method 1: Using remote task reference
deployed_task = flyte.remote.Task.get("my_env.my_task", version="v1.0.0")
result = flyte.run(deployed_task, name="World")
# Method 2: Get latest version
deployed_task = flyte.remote.Task.get("my_env.my_task", auto_version="latest")
result = flyte.run(deployed_task, name="World")
```
**Benefits of running deployed tasks:**
- **Performance**: No deployment overhead, faster execution startup
- **Stability**: Uses tested, stable versions of your code
- **Production safety**: Isolated from local development changes
- **Version control**: Explicit control over which code version runs
## Local execution
For development, debugging, and testing, you can run tasks locally on your machine without any backend interaction.
### CLI: Local execution
```bash
# Run locally with --local flag
flyte run --local my_example.py my_task --name "World"
# Local execution with development data
flyte run --local data_pipeline.py process_data --input_path "/local/data" --debug true
```
### SDK: Local execution
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def my_task(name: str) -> str:
return f"Hello, {name}!"
# Method 1: No client configured (defaults to local)
result = flyte.run(my_task, name="World")
# Method 2: Explicit local mode
flyte.init_from_config() # Client configured
result = flyte.with_runcontext(mode="local").run(my_task, name="World")
```
**Benefits of local execution:**
- **Rapid development**: Instant feedback without network latency
- **Debugging**: Full access to local debugging tools
- **Offline development**: Works without backend connectivity
- **Resource efficiency**: Uses local compute resources
## Running tasks through the Union UI
If you are running your Flyte code on a Union backend, the UI provides an interactive way to run deployed tasks with form-based input and real-time monitoring.
### Accessing task execution in the Union UI
1. **Navigate to tasks**: Go to your project → domain → Tasks section
2. **Select task**: Choose the task environment and specific task
3. **Launch execution**: Click "Launch" to open the execution form
4. **Provide inputs**: Fill in task parameters through the web interface
5. **Monitor progress**: Watch real-time execution progress and logs
**UI execution benefits:**
- **User-friendly**: No command-line expertise required
- **Visual monitoring**: Real-time progress visualization
- **Input validation**: Built-in parameter validation and type checking
- **Execution history**: Easy access to previous runs and results
- **Sharing**: Shareable execution URLs for collaboration
Here is a short video demonstrating task execution through the Union UI:
[Watch on YouTube](https://www.youtube.com/watch?v=8jbau9yGoDg)
## Execution flow and architecture
### Fast registration architecture
Flyte v2 uses "fast registration" to enable rapid development cycles:
#### How it works
1. **Container images** contain the runtime environment and dependencies
2. **Code bundles** contain your Python source code (stored separately)
3. **At runtime**: Code bundles are downloaded and injected into running containers
#### Benefits
- **Rapid iteration**: Update code without rebuilding images
- **Resource efficiency**: Share images across multiple deployments
- **Version flexibility**: Run different code versions with same base image
- **Caching optimization**: Separate caching for images vs. code
#### When code gets injected
At task execution time, the fast registration process follows these steps:
1. **Container starts** with the base image containing runtime environment and dependencies
2. **Code bundle download**: The Flyte agent downloads your Python code bundle from storage
3. **Code extraction**: The code bundle is extracted and mounted into the running container
4. **Task execution**: Your task function executes with the injected code
### Ephemeral preparation logic
When using ephemeral deploy + run mode, Flyte determines whether temporary preparation is needed:
```mermaid
graph TD
A[flyte run command] --> B{Need preparation?}
B -->|Yes| C[Ephemeral preparation]
B -->|No| D[Use cached preparation]
C --> E[Execute task]
D --> E
E --> F[Cleanup ephemeral environment]
```
### Execution modes comparison
| Mode | Deployment | Performance | Use Case | Code Version |
|------|------------|-------------|-----------|--------------|
| Ephemeral Deploy + Run | Ephemeral (temporary) | Medium | Development, testing | Latest local |
| Run Deployed | None (uses permanent deployment) | Fast | Production, stable runs | Deployed version |
| Local | None | Variable | Development, debugging | Local |
| UI | None | Fast | Interactive, collaboration | Deployed version |
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/run-command-options ===
# Run command options
The `flyte run` command provides the following options:
**`flyte run [OPTIONS] <file.py>|deployed-task <task_name> [task args]`**
| Option | Short | Type | Default | Description |
|-----------------------------|-------|--------|---------------------------|--------------------------------------------------------|
| `--project` | `-p` | text | *from config* | Project to run tasks in |
| `--domain` | `-d` | text | *from config* | Domain to run tasks in |
| `--local` | | flag | `false` | Run the task locally |
| `--copy-style` | | choice | `loaded_modules` | Code bundling strategy: `loaded_modules`, `all`, or `none` |
| `--root-dir` | | path | *current dir* | Override source root directory |
| `--raw-data-path` | | text | | Override the output location for offloaded data types. |
| `--service-account` | | text | | Kubernetes service account. |
| `--name` | | text | | Name of the run. |
| `--follow` | `-f` | flag | `false` | Wait and watch logs for the parent action. |
| `--image` | | text | | Image to be used in the run (format: `name=uri`). |
| `--no-sync-local-sys-paths` | | flag | `false` | Disable synchronization of local sys.path entries. |
## `--project`, `--domain`
**`flyte run --domain <domain> --project <project> <file.py>|deployed-task <task_name>`**
You can specify `--project` and `--domain` which will override any defaults defined in your `config.yaml`:
```bash
# Use defaults from default config.yaml
flyte run my_example.py my_task
# Specify target project and domain
flyte run --project my-project --domain development my_example.py my_task
```
## `--local`
**`flyte run --local <file.py> <task_name>`**
The `--local` option runs tasks locally instead of submitting them to the remote Flyte backend:
```bash
# Run task locally (default behavior when using flyte.run() without deployment)
flyte run --local my_example.py my_task --input "test_data"
# Compare with remote execution
flyte run my_example.py my_task --input "test_data" # Runs on Flyte backend
```
### When to use local execution
- **Development and testing**: Quick iteration without deployment overhead
- **Debugging**: Full access to local debugging tools and environment
- **Resource constraints**: When remote resources are unavailable or expensive
- **Data locality**: When working with large local datasets
## `--copy-style`
**`flyte run --copy-style [loaded_modules|all|none] <file.py> <task_name>`**
The `--copy-style` option controls code bundling for remote execution.
This applies to the ephemeral preparation step of the `flyte run` command and works similarly to `flyte deploy`:
```bash
# Smart bundling (default) - includes only imported project modules
flyte run --copy-style loaded_modules my_example.py my_task
# Include all project files
flyte run --copy-style all my_example.py my_task
# No code bundling (task must be pre-deployed)
flyte run --copy-style none deployed_task my_deployed_task
```
### Copy style options
- **`loaded_modules` (default)**: Bundles only imported Python modules from your project
- **`all`**: Includes all files in the project directory
- **`none`**: No bundling; requires permanently deployed tasks
## `--root-dir`
**`flyte run --root-dir <directory> <file.py> <task_name>`**
Override the source directory for code bundling and import resolution:
```bash
# Run from monorepo root with specific root directory
flyte run --root-dir ./services/ml ./services/ml/my_example.py my_task
# Handle cross-directory imports
flyte run --root-dir .. my_example.py my_workflow # When my_example.py imports sibling directories
```
This applies to the ephemeral preparation step of the `flyte run` command.
It works identically to the `flyte deploy` command's `--root-dir` option.
## `--raw-data-path`
**`flyte run --raw-data-path <path> <file.py> <task_name>`**
Override the default output location for offloaded data types (large objects, DataFrames, etc.):
```bash
# Use custom S3 location for large outputs
flyte run --raw-data-path s3://my-bucket/custom-path/ my_example.py process_large_data
# Use local directory for development
flyte run --local --raw-data-path ./output/ my_example.py my_task
```
### Use cases
- **Custom storage locations**: Direct outputs to specific S3 buckets or paths
- **Cost optimization**: Use cheaper storage tiers for temporary data
- **Access control**: Ensure outputs go to locations with appropriate permissions
- **Local development**: Store large outputs locally when testing
## `--service-account`
**`flyte run --service-account <account> <file.py> <task_name>`**
Specify a Kubernetes service account for task execution:
```bash
# Run with specific service account for cloud resource access
flyte run --service-account ml-service-account my_example.py train_model
# Use service account with specific permissions
flyte run --service-account data-reader-sa my_example.py load_data
```
### Use cases
- **Cloud resource access**: Service accounts with permissions for S3, GCS, etc.
- **Security isolation**: Different service accounts for different workload types
- **Compliance requirements**: Enforcing specific identity and access policies
## `--name`
**`flyte run --name <run_name> <file.py> <task_name>`**
Provide a custom name for the execution run:
```bash
# Named execution for easy identification
flyte run --name "daily-training-run-2024-12-02" my_example.py train_model
# Include experiment parameters in name
flyte run --name "experiment-lr-0.01-batch-32" my_example.py hyperparameter_sweep
```
### Benefits of custom names
- **Easy identification**: Find specific runs in the Flyte console
- **Experiment tracking**: Include key parameters or dates in names
- **Automation**: Programmatically generate meaningful names for scheduled runs, as sketched below
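For the automation case, a minimal sketch can generate a date-stamped name and hand it to the documented `--name` flag; the `subprocess` invocation and the file and task names are illustrative:
```python
import datetime
import subprocess

# Generate a meaningful, date-stamped name for a scheduled run
run_name = f"daily-training-run-{datetime.date.today().isoformat()}"

# Invoke the CLI with the generated name (illustrative wrapper)
subprocess.run(
    ["flyte", "run", "--name", run_name, "my_example.py", "train_model"],
    check=True,  # raise if the CLI exits non-zero
)
```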
## `--follow`
**`flyte run --follow <file.py> <task_name>`**
Wait and watch logs for the execution in real-time:
```bash
# Stream logs to console and wait for completion
flyte run --follow my_example.py long_running_task
# Combine with other options
flyte run --follow --name "training-session" my_example.py train_model
```
### Behavior
- **Log streaming**: Real-time output from task execution
- **Blocking execution**: Command waits until task completes
- **Exit codes**: Returns an appropriate exit code based on task success or failure (see the SDK analogue below)
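The rough SDK analogue of `--follow` is to start a run and then block on it, as in this sketch (real-time log streaming may differ from the CLI):
```python
import flyte

env = flyte.TaskEnvironment(name="follow_example")

@env.task
def long_running_task() -> str:
    return "done"

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(long_running_task)
    print(run.url)  # link to the run in the console
    run.wait()      # block until the run completes
```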
## `--image`
**`flyte run --image <name=uri> <file.py> <task_name>`**
Override container images during ephemeral preparation, same as the equivalent `flyte deploy` option:
```bash
# Override specific named image
flyte run --image gpu=ghcr.io/org/gpu:v2.1 my_example.py gpu_task
# Override default image
flyte run --image ghcr.io/org/custom:latest my_example.py my_task
# Multiple image overrides
flyte run \
--image base=ghcr.io/org/base:v1.0 \
--image gpu=ghcr.io/org/gpu:v2.0 \
my_example.py multi_env_workflow
```
### Image mapping formats
- **Named mapping**: `name=uri` overrides images created with `Image.from_ref_name("name")`
- **Default mapping**: `uri` overrides the default "auto" image
- **Multiple mappings**: Use multiple `--image` flags for different image references
## `--no-sync-local-sys-paths`
**`flyte run --no-sync-local-sys-paths <file.py> <task_name>`**
Disable synchronization of local `sys.path` entries to the remote execution environment during ephemeral preparation.
Identical to the `flyte deploy` command's `--no-sync-local-sys-paths` option:
```bash
# Disable path synchronization for clean container environment
flyte run --no-sync-local-sys-paths my_example.py my_task
```
This advanced option is useful for:
- **Container isolation**: Prevent local development paths from affecting remote execution
- **Custom environments**: When containers have pre-configured Python paths
- **Security**: Avoiding exposure of local directory structures
## Task argument passing
Arguments are passed directly as function parameters:
```bash
# CLI: Arguments as flags
flyte run my_file.py my_task --name "World" --count 5 --debug true
# SDK: Arguments as function parameters
result = flyte.run(my_task, name="World", count=5, debug=True)
```
## SDK options
The core `flyte run` functionality is also available programmatically through the `flyte.run()` function, with extensive configuration options available via the `flyte.with_runcontext()` function:
```python
# Run context configuration
result = flyte.with_runcontext(
mode="remote", # "remote", "local"
copy_style="loaded_modules", # Code bundling strategy
version="v1.0.0", # Ephemeral preparation version
dry_run=False, # Preview mode
).run(my_task, name="World")
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/how-task-deployment-works ===
# How task deployment works
In this section, we will take a deep dive into how the `flyte deploy` command and the `flyte.deploy()` SDK function work under the hood to deploy tasks to your Flyte backend.
When you perform a deployment, here's what happens:
## 1. Module loading and task environment discovery
In the first step, Flyte determines which files to load in order to search for task environments, based on the command line options provided:
### Single file (default)
```bash
flyte deploy my_example.py env
```
- The file `my_example.py` is executed.
- All declared `TaskEnvironment` objects in the file are instantiated,
but only the one assigned to the variable `env` is selected for deployment.
### `--all` option
```bash
flyte deploy --all my_example.py
```
- The file `my_example.py` is executed.
- All declared `TaskEnvironment` objects in the file are instantiated and selected for deployment.
- No specific variable name is required.
### `--recursive` option
```bash
flyte deploy --recursive ./directory
```
- The directory is recursively traversed; every Python file is executed, and all `TaskEnvironment` objects are instantiated.
- All `TaskEnvironment` objects across all files are selected for deployment.
## 2. Task analysis and serialization
- For every task environment selected for deployment, all of its tasks are identified.
- Task metadata is extracted: parameter types, return types, and resource requirements (illustrated below).
- Each task is serialized into a Flyte `TaskTemplate`.
- Dependency graphs between environments are built (see below).
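As an illustration of the metadata-extraction step, plain Python introspection already recovers parameter and return types. This is a sketch of the idea, not Flyte's actual `TaskTemplate` serializer:
```python
from typing import get_type_hints

def task_signature(fn) -> dict:
    """Extract the kind of metadata described above from a task function."""
    hints = get_type_hints(fn)
    return_type = hints.pop("return", None)
    return {
        "name": fn.__name__,
        "parameters": hints,        # parameter name -> declared type
        "return_type": return_type,
    }

def greet(name: str, count: int) -> str:
    return f"Hello, {name}!" * count

print(task_signature(greet))
# {'name': 'greet', 'parameters': {'name': <class 'str'>, 'count': <class 'int'>}, 'return_type': <class 'str'>}
```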
## 3. Task environment dependency resolution
In many cases, a task in one environment may invoke a task in another environment, establishing a dependency between the two environments.
For example, if `env_a` has a task that calls a task in `env_b`, then `env_a` depends on `env_b`.
This means that when deploying `env_a`, `env_b` must also be deployed to ensure that all tasks can be executed correctly.
To handle this, `TaskEnvironment`s can declare dependencies on other `TaskEnvironment`s using the `depends_on` parameter.
During deployment, the system performs the following steps to resolve these dependencies:
1. Start with the specified environment(s)
2. Recursively discover all transitive dependencies
3. Include every discovered dependency in the deployment plan
4. Process dependencies depth-first so that each environment is deployed before its dependents
```python
# Define environments with dependencies
prep_env = flyte.TaskEnvironment(name="preprocessing")
ml_env = flyte.TaskEnvironment(name="ml_training", depends_on=[prep_env])
viz_env = flyte.TaskEnvironment(name="visualization", depends_on=[ml_env])
# Deploy only viz_env - automatically includes ml_env and prep_env
deployment = flyte.deploy(viz_env, version="v2.0.0")
# Or deploy multiple environments explicitly
deployment = flyte.deploy(prep_env, ml_env, viz_env, version="v2.0.0")
```
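A minimal sketch of this depth-first resolution, using a stand-in `Env` class rather than the real `TaskEnvironment` (illustrative; Flyte's actual planner may differ):
```python
from dataclasses import dataclass, field

@dataclass
class Env:
    """Stand-in for TaskEnvironment, carrying only a name and dependencies."""
    name: str
    depends_on: list["Env"] = field(default_factory=list)

def resolve(requested: list[Env]) -> list[Env]:
    plan: list[Env] = []
    def visit(env: Env) -> None:
        for dep in env.depends_on:
            visit(dep)               # 2. recurse into transitive dependencies
        if env not in plan:
            plan.append(env)         # 3./4. dependencies land before dependents
    for env in requested:            # 1. start with the specified environment(s)
        visit(env)
    return plan

prep = Env("preprocessing")
ml = Env("ml_training", depends_on=[prep])
viz = Env("visualization", depends_on=[ml])
print([e.name for e in resolve([viz])])
# ['preprocessing', 'ml_training', 'visualization']
```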
For detailed information about working with multiple environments, see **Configure tasks > Multiple environments**.
## 4. Code bundle creation and upload
Once the task environments and their dependencies are resolved, Flyte proceeds to package your code into a bundle based on the `copy_style` option:
### `--copy-style loaded_modules` (default)
This is the smart bundling approach that analyzes which Python modules were actually imported during the task environment discovery phase.
It examines the runtime module registry (`sys.modules`) and includes only those modules that meet specific criteria:
they must have source files located within your project directory (not in system locations like `site-packages`), and they must not be part of the Flyte SDK itself.
This selective approach results in smaller, faster-to-upload bundles that contain exactly the code needed to run your tasks, making it ideal for most development and production scenarios.
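The selection criteria can be illustrated with a short pass over `sys.modules`. This is a sketch of the rules described above, not Flyte's implementation, and `project_modules` is a hypothetical helper:
```python
import sys
from pathlib import Path

def project_modules(project_root: Path) -> list[str]:
    """Return names of loaded modules whose source lives under project_root."""
    root = project_root.resolve()
    selected = []
    for name, module in list(sys.modules.items()):
        source = getattr(module, "__file__", None)
        if module is None or source is None:
            continue  # built-ins and namespace packages have no source file
        if not Path(source).resolve().is_relative_to(root):
            continue  # excludes site-packages and other system locations
        if name == "flyte" or name.startswith("flyte."):
            continue  # the Flyte SDK itself is never bundled
        selected.append(name)
    return selected
```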
### `--copy-style all`
This comprehensive bundling strategy takes a directory-walking approach, recursively traversing your entire project directory and including every file it encounters.
Unlike the smart bundling that only includes imported Python modules, this method captures all project files regardless of whether they were imported during discovery.
This is particularly useful for projects that use dynamic imports, load configuration files or data assets at runtime, or have dependencies that aren't captured through normal Python import mechanisms.
### `--copy-style none`
This option completely skips code bundle creation, meaning no source code is packaged or uploaded to cloud storage.
When using this approach, you must provide an explicit version parameter since there's no code bundle to generate a version from.
This strategy is designed for scenarios where your code is already baked into custom container images, eliminating the need for separate code injection during task execution.
It results in the fastest deployment times but requires more complex image management workflows.
### `--root-dir` option
By default, Flyte uses your current working directory as the root for code bundling.
You can override this with `--root-dir` to specify a different base directory, which is particularly useful for monorepos or when deploying from subdirectories. This affects all copy styles: `loaded_modules` looks for imported modules relative to the root directory, and `all` walks the directory tree starting from the root. See **Run and deploy tasks > Deploy command options > `--root-dir`** for detailed usage examples.
After the code bundle is created (if applicable), it is uploaded to a cloud storage location (like S3 or GCS) accessible by your Flyte backend. It is now ready to be run.
## 5. Image building
If your `TaskEnvironment` specifies **Configure tasks > Container images**, Flyte builds and pushes container images before deploying tasks.
The build process varies based on your configuration and backend type:
### Local image building
When `image.builder` is set to `local` in **Getting started > Local setup**, images are built on your local machine using Docker. This approach:
- Requires Docker to be installed and running on your development machine
- Uses Docker BuildKit to build images from generated Dockerfiles or your custom Dockerfile
- Pushes built images to the container registry specified in your `Image` configuration
- Is the only option available for Flyte OSS instances
### Remote image building
When `image.builder` is set to `remote` in **Getting started > Local setup**, images are built on cloud infrastructure. This approach:
- Builds images using Union's ImageBuilder service (currently only available for Union backends, not OSS Flyte)
- Requires no local Docker installation or configuration
- Can push to Union's internal registry or external registries you specify
- Provides faster, more consistent builds by leveraging cloud resources
> [!NOTE]
> Remote building is currently exclusive to Union backends. OSS Flyte installations must use `local`.
## Understanding option relationships
It's important to understand how the various deployment options work together.
The **discovery options** (`--recursive` and `--all`) operate independently of the **bundling options** (`--copy-style`),
giving you flexibility in how you structure your deployments.
Environment discovery determines which files Flyte will examine to find `TaskEnvironment` objects,
while code bundling controls what gets packaged and uploaded for execution.
You can freely combine these approaches: for example, discover environments recursively across your entire project while using smart bundling to include only the necessary code modules.
When multiple environments are discovered, they all share the same code bundle, which is efficient for related services or components that use common dependencies:
```bash
# All discovered environments share the same code bundle
flyte deploy --recursive --copy-style loaded_modules ./project
```
For a full overview of all deployment options, see **Flyte CLI > flyte > flyte deploy**.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/deploy-command-options ===
# Deploy command options
The `flyte deploy` command provides extensive configuration options:
**`flyte deploy [OPTIONS] <file.py|directory> [TASK_ENV_VARIABLE]`**
| Option | Short | Type | Default | Description |
|-----------------------------|-------|--------|------------------|-----------------------------------------------------------|
| `--project` | `-p` | text | *from config* | Project to deploy to |
| `--domain` | `-d` | text | *from config* | Domain to deploy to |
| `--version` | | text | *auto-generated* | Explicit version tag for deployment |
| `--dry-run`/`--dryrun` | | flag | `false` | Preview deployment without executing |
| `--all` | | flag | `false` | Deploy all environments in specified path |
| `--recursive` | `-r` | flag | `false` | Deploy environments recursively in subdirectories |
| `--copy-style` | | choice | `loaded_modules` | Code bundling strategy (`loaded_modules`, `all`, `none`) |
| `--root-dir` | | path | *current dir* | Override source root directory |
| `--image` | | text | | Image URI mappings (format: `name=uri`) |
| `--ignore-load-errors` | `-i` | flag | `false` | Continue deployment despite module load failures |
| `--no-sync-local-sys-paths` | | flag | `false` | Disable local `sys.path` synchronization |
## `--project`, `--domain`
**`flyte deploy --domain <domain> --project <project> <file.py> <env_variable>`**
You can specify `--project` and `--domain` to override any defaults defined in your `config.yaml`:
```bash
# Use defaults from default config.yaml
flyte deploy my_example.py env
# Specify target project and domain
flyte deploy --project my-project --domain development my_example.py env
```
## `--version`
**`flyte deploy --version <version> <file.py> <env_variable>`**
The `--version` option controls how deployed tasks are tagged and identified in the Flyte backend:
```bash
# Auto-generated version (default)
flyte deploy my_example.py env
# Explicit version
flyte deploy --version v1.0.0 my_example.py env
# Required when using copy-style none (no code bundle to generate hash from)
flyte deploy --copy-style none --version v1.0.0 my_example.py env
```
### When versions are used
- **Explicit versioning**: Provides human-readable task identification (e.g., `v1.0.0`, `prod-2024-12-01`)
- **Auto-generated versions**: When no version is specified, Flyte creates an MD5 hash from the code bundle, environment configuration, and image cache (sketched below)
- **Version requirement**: `--copy-style none` mandates explicit versions since there's no code bundle to hash
- **Task referencing**: Versions enable precise task references in `flyte run deployed_task` and workflow invocations
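A toy illustration of this content-based versioning (the exact inputs and their encoding here are assumptions, not Flyte's actual scheme):
```python
import hashlib

def auto_version(code_bundle: bytes, env_config: str, image_id: str) -> str:
    """MD5 over the deployment inputs, as described above (illustrative)."""
    digest = hashlib.md5()
    digest.update(code_bundle)          # bytes of the uploaded code bundle
    digest.update(env_config.encode())  # serialized environment configuration
    digest.update(image_id.encode())    # image cache identifier
    return digest.hexdigest()

print(auto_version(b"tarball-bytes", "env: my_env", "ghcr.io/org/base:v1.0"))
```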
## `--dry-run`
**`flyte deploy --dry-run <file.py> <env_variable>`**
The `--dry-run` option allows you to preview what would be deployed without actually performing the deployment:
```bash
# Preview what would be deployed
flyte deploy --dry-run my_example.py env
```
## `--all` and `--recursive`
**`flyte deploy --all <file.py>`**
**`flyte deploy --recursive <directory>`**
Control which environments get discovered and deployed:
**Single environment (default):**
```bash
# Deploy specific environment variable
flyte deploy my_example.py env
```
**All environments in file:**
```bash
# Deploy all TaskEnvironment objects in file
flyte deploy --all my_example.py
```
**Recursive directory deployment:**
```bash
# Deploy all environments in directory tree
flyte deploy --recursive ./src
# Combine with comprehensive bundling
flyte deploy --recursive --copy-style all ./project
```
## `--copy-style`
**`flyte deploy --copy-style [loaded_modules|all|none] <file.py> <env_variable>`**
The `--copy-style` option controls what gets packaged:
### `--copy-style loaded_modules` (default)
```bash
flyte deploy --copy-style loaded_modules my_example.py env
```
- **Includes**: Only imported Python modules from your project
- **Excludes**: Site-packages, system modules, Flyte SDK
- **Best for**: Most projects (optimal size and speed)
### `--copy-style all`
```bash
flyte deploy --copy-style all my_example.py env
```
- **Includes**: All files in project directory
- **Best for**: Projects with dynamic imports or data files
### `--copy-style none`
```bash
flyte deploy --copy-style none --version v1.0.0 my_example.py env
```
- **Requires**: Explicit version parameter
- **Best for**: Pre-built container images with baked-in code
## `--root-dir`
**`flyte deploy --root-dir <directory> <file.py> <env_variable>`**
The `--root-dir` option overrides the default source directory that Flyte uses as the base for code bundling and import resolution.
This is particularly useful for monorepos and projects with complex directory structures.
### Default behavior (without `--root-dir`)
- Flyte uses the current working directory as the root
- Code bundling starts from this directory
- Import paths are resolved relative to this location
### Common use cases
**Monorepos:**
```bash
# Deploy service from monorepo root
flyte deploy --root-dir ./services/ml ./services/ml/my_example.py env
# Deploy from anywhere in the monorepo
cd ./docs/
flyte deploy --root-dir ../services/ml ../services/ml/my_example.py env
```
**Cross-directory imports:**
```bash
# When workflow imports modules from sibling directories
# Project structure: project/workflows/my_example.py imports project/src/utils.py
cd project/workflows/
flyte deploy --root-dir .. my_example.py env # Sets root to project/
```
**Working directory independence:**
```bash
# Deploy from any location while maintaining consistent bundling
flyte deploy --root-dir /path/to/project /path/to/project/my_example.py env
```
### How it works
1. **Code bundling**: Files are collected starting from `--root-dir` instead of the current working directory
2. **Import resolution**: Python imports are resolved relative to the specified root directory
3. **Path consistency**: Ensures the same directory structure in local and remote execution environments
4. **Dependency packaging**: Captures all necessary modules that may be located outside the workflow file's immediate directory
### Example with complex project structure
```
my-project/
├── services/
│   ├── ml/
│   │   └── my_example.py    # imports shared.utils
│   └── api/
└── shared/
    └── utils.py
```
```bash
# Deploy ML service workflows with access to shared utilities
flyte deploy --root-dir ./my-project ./my-project/services/ml/my_example.py env
```
This ensures that both `services/ml/` and `shared/` directories are included in the code bundle, allowing the workflow to successfully import `shared.utils` during remote execution.
## `--image`
**`flyte deploy --image <name=uri> <file.py> <env_variable>`**
The `--image` option allows you to override image URIs at deployment time without modifying your code. Format: `imagename=imageuri`
### Named image mappings
```bash
# Map specific image reference to URI
flyte deploy --image base=ghcr.io/org/base:v1.0 my_example.py env
# Multiple named image mappings
flyte deploy \
--image base=ghcr.io/org/base:v1.0 \
--image gpu=ghcr.io/org/gpu:v2.0 \
my_example.py env
```
### Default image mapping
```bash
# Override default image (used when no specific image is set)
flyte deploy --image ghcr.io/org/default:latest my_example.py env
```
### How it works
- Named mappings (e.g., `base=URI`) override images created with `Image.from_ref_name("base")`.
- Unnamed mappings (e.g., just `URI`) override the default "auto" image.
- Multiple `--image` flags can be specified.
- Mappings are resolved during the image building phase of deployment (a toy sketch follows).
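A toy sketch of this resolution order (`resolve_image` is a hypothetical helper, not an SDK function):
```python
def resolve_image(ref_name: str | None, mappings: dict[str, str]) -> str:
    """Resolve an image reference against `--image name=uri` style mappings."""
    if ref_name is not None:
        return mappings[ref_name]  # named mapping overrides Image.from_ref_name(ref_name)
    return mappings["auto"]        # bare URI mapping overrides the default "auto" image

mappings = {"base": "ghcr.io/org/base:v1.0", "auto": "ghcr.io/org/default:latest"}
print(resolve_image("base", mappings))  # ghcr.io/org/base:v1.0
print(resolve_image(None, mappings))    # ghcr.io/org/default:latest
```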
## `--ignore-load-errors`
**`flyte deploy --ignore-load-errors <file.py|directory>`**
The `--ignore-load-errors` option allows the deployment process to continue even if some modules fail to load during the environment discovery phase. This is particularly useful for large projects or monorepos where certain modules may have missing dependencies or other issues that prevent them from being imported successfully.
```bash
# Continue deployment despite module failures
flyte deploy --recursive --ignore-load-errors ./large-project
```
## `--no-sync-local-sys-paths`
**`flyte deploy --no-sync-local-sys-paths <file.py> <env_variable>`**
The `--no-sync-local-sys-paths` option disables the automatic synchronization of local `sys.path` entries to the remote container environment. This is an advanced option for specific deployment scenarios.
### Default behavior (path synchronization enabled)
- Flyte captures local `sys.path` entries that are under the root directory
- These paths are passed to the remote container via the `_F_SYS_PATH` environment variable
- At runtime, the remote container adds these paths to its `sys.path`, maintaining the same import environment (sketched below)
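Conceptually, the runtime side of this synchronization looks like the sketch below; the actual entrypoint logic and the separator used in `_F_SYS_PATH` are assumptions:
```python
import os
import sys

# Reproduce the local import environment inside the container (illustrative)
synced = os.environ.get("_F_SYS_PATH", "")
for path in filter(None, synced.split(",")):  # the "," separator is an assumption
    if path not in sys.path:
        sys.path.append(path)
```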
### When to disable path synchronization
```bash
# Disable local sys.path sync (advanced use case)
flyte deploy --no-sync-local-sys-paths my_example.py env
```
### Use cases for disabling
- **Custom container images**: When your container already has the correct `sys.path` configuration
- **Conflicting path structures**: When local development paths would interfere with container paths
- **Security concerns**: When you don't want to expose local development directory structures
- **Minimal environments**: When you want precise control over what gets added to the container's Python path
### How it works
- **Enabled (default)**: Local paths like `./my_project/utils` get synchronized and added to remote `sys.path`
- **Disabled**: Only the container's native `sys.path` is used, along with the deployed code bundle
Most users should leave path synchronization enabled unless they have specific requirements for container path isolation or are using pre-configured container environments.
## SDK deployment options
The core deployment functionality is available programmatically through the `flyte.deploy()` function, though some CLI-specific options are not applicable:
```python
import flyte
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def process_data(data: str) -> str:
return f"Processed: {data}"
if __name__ == "__main__":
flyte.init_from_config()
# Comprehensive deployment configuration
deployment = flyte.deploy(
env, # Environment to deploy
dryrun=False, # Set to True for dry run
version="v1.2.0", # Explicit version tag
copy_style="loaded_modules" # Code bundling strategy
)
print(f"Deployment successful: {deployment[0].summary_repr()}")
```
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/packaging ===
# Code packaging for remote execution
When you run Flyte tasks remotely, your code needs to be available in the execution environment. Flyte SDK provides two main approaches for packaging your code:
1. **Code bundling** - Bundle code dynamically at runtime
2. **Container-based deployment** - Embed code directly in container images
## Quick comparison
| Aspect | Code bundling | Container-based |
|--------|---------------|-----------------|
| **Speed** | Fast (no image rebuild) | Slower (requires image build) |
| **Best for** | Rapid development, iteration | Production, immutable deployments |
| **Code changes** | Immediate effect | Requires image rebuild |
| **Setup** | Automatic by default | Manual configuration needed |
| **Reproducibility** | Excellent (hash-based versioning) | Excellent (immutable images) |
| **Rollback** | Requires version control | Tag-based, straightforward |
---
## Code bundling
**Default approach** - Automatically bundles and uploads your code to remote storage at runtime.
### How it works
When you run `flyte run` or call `flyte.run()`, Flyte automatically:
1. **Scans loaded modules** from your codebase
2. **Creates a tarball** (gzipped, without timestamps for consistent hashing)
3. **Uploads to blob storage** (S3, GCS, Azure Blob)
4. **Deduplicates** based on content hashes
5. **Downloads in containers** at runtime
This process happens transparently - every container downloads and extracts the code bundle before execution.
> [!NOTE]
> Code bundling is optimized for speed:
> - Bundles are created without timestamps for consistent hashing
> - Identical code produces identical hashes, enabling deduplication
> - Only modified code triggers new uploads
> - Containers cache downloaded bundles
>
> **Reproducibility:** Flyte automatically versions code bundles based on content hash. The same code always produces the same hash, guaranteeing reproducibility without manual versioning. However, version control is still recommended for rollback capabilities.
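The timestamp-free property can be reproduced with standard-library tools. The sketch below shows why identical code yields identical hashes; it is not Flyte's packaging code, and the zeroed metadata fields are illustrative:
```python
import gzip
import hashlib
import io
import tarfile
from pathlib import Path

def bundle(files: list[Path]) -> tuple[bytes, str]:
    """Build a deterministic .tar.gz and return (bytes, content hash)."""
    buf = io.BytesIO()
    # mtime=0 in the gzip header keeps the byte stream identical across builds
    with gzip.GzipFile(fileobj=buf, mode="wb", mtime=0) as gz:
        with tarfile.open(fileobj=gz, mode="w") as tar:
            for f in sorted(files):              # stable file ordering
                info = tar.gettarinfo(str(f))
                info.mtime = 0                   # strip per-file timestamps
                info.uid = info.gid = 0          # normalize ownership metadata
                info.uname = info.gname = ""
                with f.open("rb") as fh:
                    tar.addfile(info, fh)
    data = buf.getvalue()
    return data, hashlib.md5(data).hexdigest()   # identical code -> identical hash
```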
### Automatic code bundling
**Default behavior** - Bundles all loaded modules automatically.
#### What gets bundled
Flyte includes modules that are:
- ✅ **Loaded when environment is parsed** (imported at module level)
- ✅ **Part of your codebase** (not system packages)
- ✅ **Within your project directory**
- ❌ **NOT lazily loaded** (imported inside functions)
- ❌ **NOT system-installed packages** (e.g., from site-packages)
#### Example: Basic automatic bundling
```python
# app.py
import flyte
from my_module import helper # ✅ Bundled automatically
env = flyte.TaskEnvironment(
name="default",
image=flyte.Image.from_debian_base().with_pip_packages("pandas", "numpy")
)
@env.task
def process_data(x: int) -> int:
# This import won't be bundled (lazy load)
from another_module import util # ❌ Not bundled automatically
return helper.transform(x)
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(process_data, x=42)
print(run.url)
```
When you run this:
```bash
flyte run app.py process_data --x 42
```
Flyte automatically:
1. Bundles `app.py` and `my_module.py`
2. Preserves the directory structure
3. Uploads to blob storage
4. Makes it available in the remote container
#### Project structure example
```
my_project/
├── app.py                   # Main entry point
├── tasks/
│   ├── __init__.py
│   ├── data_tasks.py        # Flyte tasks
│   └── ml_tasks.py
└── utils/
    ├── __init__.py
    ├── preprocessing.py     # Business logic
    └── models.py
```
```python
# app.py
import flyte
from tasks.data_tasks import load_data # ✅ Bundled
from tasks.ml_tasks import train_model # ✅ Bundled
# utils modules imported in tasks are also bundled
@flyte.task
def pipeline(dataset: str) -> float:
data = load_data(dataset)
accuracy = train_model(data)
return accuracy
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(pipeline, dataset="train.csv")
```
**All modules are bundled with their directory structure preserved.**
### Manual code bundling
Control exactly what gets bundled by configuring the copy style.
#### Copy styles
Three options available:
1. **`"auto"`** (default) - Bundle loaded modules only
2. **`"all"`** - Bundle everything in the working directory
3. **`"none"`** - Skip bundling entirely (requires code in container)
#### Using `copy_style="all"`
Bundle all files under your project directory:
```python
import flyte
flyte.init_from_config()
# Bundle everything in current directory
run = flyte.with_runcontext(copy_style="all").run(
my_task,
input_data="sample.csv"
)
```
Or via CLI:
```bash
flyte run --copy-style=all app.py my_task --input-data sample.csv
```
**Use when:**
- You have data files or configuration that tasks need
- You use dynamic imports or lazy loading
- You want to ensure all project files are available
#### Using `copy_style="none"`
Skip code bundling (see **Run and deploy tasks > Code packaging for remote execution > Container-based deployment**):
```python
run = flyte.with_runcontext(copy_style="none").run(my_task, x=10)
```
### Controlling the root directory
The `root_dir` parameter controls which directory serves as the bundling root.
#### Why root directory matters
1. **Determines what gets bundled** - All code paths are relative to root_dir
2. **Preserves import structure** - Python imports must match the bundle structure
3. **Affects path resolution** - Files and modules are located relative to root_dir
#### Setting root directory
##### Via CLI
```bash
flyte run --root-dir /path/to/project app.py my_task
```
##### Programmatically
```python
import pathlib
import flyte
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent
)
```
#### Root directory use cases
##### Use case 1: Multi-module project
```
project/
├── src/
│   ├── workflows/
│   │   └── pipeline.py
│   └── utils/
│       └── helpers.py
└── config.yaml
```
```python
# src/workflows/pipeline.py
import pathlib
import flyte
from utils.helpers import process # Import resolved from the project root
# Set root to project root (not src/)
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent.parent.parent
)
@flyte.task
def my_task():
return process()
```
**Root set to `project/` so imports like `from utils.helpers` work correctly.**
##### Use case 2: Shared utilities
```
workspace/
├── shared/
│   └── common.py
└── project/
    └── app.py
```
```python
# project/app.py
import flyte
import pathlib
from shared.common import shared_function # Import from parent directory
# Set root to workspace/ to include shared/
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent.parent
)
```
##### Use case 3: Monorepo
```
monorepo/
├── libs/
│   ├── data/
│   └── models/
└── services/
    └── ml_service/
        └── workflows.py
```
```python
# services/ml_service/workflows.py
import flyte
import pathlib
from libs.data import loader # Import from monorepo root
from libs.models import predictor
# Set root to monorepo/ to include libs/
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent.parent.parent
)
```
#### Root directory best practices
1. **Set root_dir at project initialization** before importing any task modules
2. **Use absolute paths** with `pathlib.Path(__file__).parent` navigation
3. **Match your import structure** - if imports are relative to project root, set root_dir to project root
4. **Keep consistent** - use the same root_dir for both `flyte run` and `flyte.init()`
### Code bundling examples
#### Example: Standard Python package
```
my_package/
├── pyproject.toml
└── src/
    └── my_package/
        ├── __init__.py
        ├── main.py
        ├── data/
        │   ├── loader.py
        │   └── processor.py
        └── models/
            └── analyzer.py
```
```python
# src/my_package/main.py
import flyte
import pathlib
from my_package.data.loader import fetch_data
from my_package.data.processor import clean_data
from my_package.models.analyzer import analyze
env = flyte.TaskEnvironment(
name="pipeline",
image=flyte.Image.from_debian_base().with_uv_project(
pyproject_file=pathlib.Path(__file__).parent.parent.parent / "pyproject.toml"
)
)
@env.task
async def fetch_task(url: str) -> dict:
return await fetch_data(url)
@env.task
def process_task(raw_data: dict) -> list[dict]:
return clean_data(raw_data)
@env.task
def analyze_task(data: list[dict]) -> str:
return analyze(data)
if __name__ == "__main__":
import flyte.git
# Set root to project root for proper imports
flyte.init_from_config(
flyte.git.config_from_root(),
root_dir=pathlib.Path(__file__).parent.parent.parent
)
# All modules bundled automatically
run = flyte.run(analyze_task, data=[{"value": 1}, {"value": 2}])
print(f"Run URL: {run.url}")
```
**Run with:**
```bash
cd my_package
flyte run src/my_package/main.py analyze_task --data '[{"value": 1}]'
```
#### Example: Dynamic environment based on domain
```python
# environment_picker.py
import flyte
def create_env():
"""Create different environments based on domain."""
if flyte.current_domain() == "development":
return flyte.TaskEnvironment(
name="dev",
image=flyte.Image.from_debian_base(),
env_vars={"ENV": "dev", "DEBUG": "true"}
)
elif flyte.current_domain() == "staging":
return flyte.TaskEnvironment(
name="staging",
image=flyte.Image.from_debian_base(),
env_vars={"ENV": "staging", "DEBUG": "false"}
)
else: # production
return flyte.TaskEnvironment(
name="prod",
image=flyte.Image.from_debian_base(),
env_vars={"ENV": "production", "DEBUG": "false"},
resources=flyte.Resources(cpu="2", memory="4Gi")
)
env = create_env()
@env.task
async def process(n: int) -> int:
import os
print(f"Running in {os.getenv('ENV')} environment")
return n * 2
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(process, n=5)
print(run.url)
```
**Why this works:**
- `flyte.current_domain()` is set correctly when Flyte re-instantiates modules remotely
- Environment configuration is deterministic and reproducible
- Code automatically bundled with domain-specific settings
> [!NOTE]
> `flyte.current_domain()` only works after `flyte.init()` is called:
> - ✅ Works with `flyte run` and `flyte deploy` (auto-initialize)
> - ✅ Works in `if __name__ == "__main__"` after explicit `flyte.init()`
> - ❌ Does NOT work at module level without initialization
### When to use code bundling
✅ **Use code bundling when:**
- Rapid development and iteration
- Frequently changing code
- Multiple developers testing changes
- Jupyter notebook workflows
- Quick prototyping and experimentation
❌ **Consider container-based instead when:**
- Need easy rollback to previous versions (container tags are simpler than finding git commits)
- Working with air-gapped environments (no blob storage access)
- Code changes require coordinated dependency updates
---
## Container-based deployment
**Advanced approach** - Embed code directly in container images for immutable deployments.
### How it works
Instead of bundling code at runtime:
1. **Build container image** with code copied inside
2. **Disable code bundling** with `copy_style="none"`
3. **Container has everything** needed at runtime
**Trade-off:** Every code change requires a new image build (slower), but provides complete reproducibility.
### Configuration
Three key steps:
#### 1. Set `copy_style="none"`
Disable runtime code bundling:
```python
flyte.with_runcontext(copy_style="none").run(my_task, n=10)
```
Or via CLI:
```bash
flyte run --copy-style=none app.py my_task --n 10
```
#### 2. Copy code into the image
Use `Image.with_source_file()` or `Image.with_source_folder()`:
```python
import pathlib
import flyte
env = flyte.TaskEnvironment(
name="embedded",
image=flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
)
```
#### 3. Set the correct `root_dir`
Match your image copy configuration:
```python
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent
)
```
### Image source copying methods
#### `with_source_file()` - Copy individual files
Copy a single file into the container:
```python
image = flyte.Image.from_debian_base().with_source_file(
src=pathlib.Path(__file__),
dst="/app/main.py"
)
```
**Use for:**
- Single-file workflows
- Copying configuration files
- Adding scripts to existing images
#### `with_source_folder()` - Copy directories
Copy entire directories into the container:
```python
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
dst="/app",
copy_contents_only=False # Copy folder itself
)
```
**Parameters:**
- `src`: Source directory path
- `dst`: Destination path in container (optional, defaults to workdir)
- `copy_contents_only`: If `True`, copies folder contents; if `False`, copies folder itself
##### `copy_contents_only=True` (Recommended)
Copies only the contents of the source folder:
```python
# Project structure:
# my_project/
# ├── app.py
# └── utils.py
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
# Container will have:
# /app/app.py
# /app/utils.py
# Set root_dir to match:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
```
##### `copy_contents_only=False`
Copies the folder itself with its name:
```python
# Project structure:
# workspace/
# └── my_project/
#     ├── app.py
#     └── utils.py
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent, # Points to my_project/
copy_contents_only=False
)
# Container will have:
# /app/my_project/app.py
# /app/my_project/utils.py
# Set root_dir to parent to match:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
### Complete container-based example
```python
# full_build.py
import pathlib
import flyte
from dep import helper # Local module
# Configure environment with source copying
env = flyte.TaskEnvironment(
name="full_build",
image=flyte.Image.from_debian_base()
.with_pip_packages("numpy", "pandas")
.with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
)
@env.task
def square(x: int) -> int:
return x ** helper.get_exponent()
@env.task
def main(n: int) -> list[int]:
return list(flyte.map(square, range(n)))
if __name__ == "__main__":
import flyte.git
# Initialize with matching root_dir
flyte.init_from_config(
flyte.git.config_from_root(),
root_dir=pathlib.Path(__file__).parent
)
# Run with copy_style="none" and explicit version
run = flyte.with_runcontext(
copy_style="none",
version="v1.0.0" # Explicit version for image tagging
).run(main, n=10)
print(f"Run URL: {run.url}")
run.wait()
```
**Project structure:**
```
project/
├── full_build.py
├── dep.py          # Local dependency
└── .flyte/
    └── config.yaml
```
**Run with:**
```bash
python full_build.py
```
This will:
1. Build a container image with `full_build.py` and `dep.py` embedded
2. Tag it as `v1.0.0`
3. Push to registry
4. Execute remotely without code bundling
### Using externally built images
When containers are built outside of Flyte (e.g., in CI/CD), use `Image.from_ref_name()`:
#### Step 1: Build your image externally
```dockerfile
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Copy your code
COPY src/ /app/
# Install dependencies
RUN pip install flyte pandas numpy
# Ensure flyte executable is available
RUN flyte --help
```
```bash
# Build in CI/CD
docker build -t myregistry.com/my-app:v1.2.3 .
docker push myregistry.com/my-app:v1.2.3
```
#### Step 2: Reference image by name
```python
# app.py
import flyte
env = flyte.TaskEnvironment(
name="external",
image=flyte.Image.from_ref_name("my-app-image") # Reference name
)
@env.task
def process(x: int) -> int:
return x * 2
if __name__ == "__main__":
flyte.init_from_config()
# Pass actual image URI at deploy/run time
run = flyte.with_runcontext(
copy_style="none",
images={"my-app-image": "myregistry.com/my-app:v1.2.3"}
).run(process, x=10)
```
Or via CLI:
```bash
flyte run \
--copy-style=none \
--image my-app-image=myregistry.com/my-app:v1.2.3 \
app.py process --x 10
```
**For deployment:**
```bash
flyte deploy \
--image my-app-image=myregistry.com/my-app:v1.2.3 \
app.py
```
#### Why use reference names?
1. **Decouples code from image URIs** - Change images without modifying code
2. **Supports multiple environments** - Different images for dev/staging/prod
3. **Integrates with CI/CD** - Build images in pipelines, reference in code
4. **Enables image reuse** - Multiple tasks can reference the same image
#### Example: Multi-environment deployment
```python
import flyte
import os
# Code references image by name
env = flyte.TaskEnvironment(
name="api",
image=flyte.Image.from_ref_name("api-service")
)
@env.task
def api_call(endpoint: str) -> dict:
# Implementation
return {"status": "success"}
if __name__ == "__main__":
flyte.init_from_config()
# Determine image based on environment
environment = os.getenv("ENV", "dev")
image_uri = {
"dev": "myregistry.com/api-service:dev",
"staging": "myregistry.com/api-service:staging",
"prod": "myregistry.com/api-service:v1.2.3"
}[environment]
run = flyte.with_runcontext(
copy_style="none",
images={"api-service": image_uri}
).run(api_call, endpoint="/health")
```
### Container-based best practices
1. **Always set explicit versions** when using `copy_style="none"`:
```python
flyte.with_runcontext(copy_style="none", version="v1.0.0")
```
2. **Match `root_dir` to `copy_contents_only`**:
- `copy_contents_only=True` → `root_dir=Path(__file__).parent`
- `copy_contents_only=False` → `root_dir=Path(__file__).parent.parent`
3. **Ensure `flyte` executable is in container** - Add to PATH or install flyte package
4. **Use `.dockerignore`** to exclude unnecessary files:
```
# .dockerignore
__pycache__/
*.pyc
.git/
.venv/
*.egg-info/
```
5. **Test containers locally** before deploying:
```bash
docker run -it myimage:latest /bin/bash
python -c "import mymodule" # Verify imports work
```
### When to use container-based deployment
✅ **Use container-based when:**
- Deploying to production
- Need immutable, reproducible environments
- Working with complex system dependencies
- Deploying to air-gapped or restricted environments
- CI/CD pipelines with automated builds
- Code changes are infrequent
❌ **Don't use container-based when:**
- Rapid development and frequent code changes
- Quick prototyping
- Interactive development (Jupyter notebooks)
- Learning and experimentation
---
## Choosing the right approach
### Decision tree
```
Are you iterating quickly on code?
├─ Yes → Use Code Bundling (Default)
│        (Development, prototyping, notebooks)
│        Both approaches are fully reproducible via hash/tag
└─ No → Do you need easy version rollback?
   ├─ Yes → Use Container-based
   │        (Production, CI/CD, straightforward tag-based rollback)
   └─ No → Either works
           (Code bundling is simpler, container-based for air-gapped)
```
### Hybrid approach
You can use different approaches for different tasks:
```python
import flyte
import pathlib
# Fast iteration for development tasks
dev_env = flyte.TaskEnvironment(
name="dev",
image=flyte.Image.from_debian_base().with_pip_packages("pandas")
# Code bundling (default)
)
# Immutable containers for production tasks
prod_env = flyte.TaskEnvironment(
name="prod",
image=flyte.Image.from_debian_base()
.with_pip_packages("pandas")
.with_source_folder(pathlib.Path(__file__).parent, copy_contents_only=True)
# Requires copy_style="none"
)
@dev_env.task
def experimental_task(x: int) -> int:
# Rapid development with code bundling
return x * 2
@prod_env.task
def stable_task(x: int) -> int:
# Production with embedded code
return x ** 2
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
# Use code bundling for dev task
dev_run = flyte.run(experimental_task, x=5)
# Use container-based for prod task
prod_run = flyte.with_runcontext(
copy_style="none",
version="v1.0.0"
).run(stable_task, x=5)
```
---
## Troubleshooting
### Import errors
**Problem:** `ModuleNotFoundError` when task executes remotely
**Solutions:**
1. **Check loaded modules** - Ensure modules are imported at module level:
```python
# ✅ Good - bundled automatically
from mymodule import helper
@flyte.task
def my_task():
return helper.process()
```
```python
# ❌ Bad - not bundled (lazy load)
@flyte.task
def my_task():
from mymodule import helper
return helper.process()
```
2. **Verify `root_dir`** matches your import structure:
```python
# If imports are: from mypackage.utils import foo
# Then root_dir should be parent of mypackage/
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
3. **Use `copy_style="all"`** to bundle everything:
```bash
flyte run --copy-style=all app.py my_task
```
### Code changes not reflected
**Problem:** Remote execution uses old code despite local changes
> [!NOTE]
> This is rare with code bundling - Flyte automatically versions based on content hash, so code changes should be detected automatically. This issue typically occurs with caching problems or when using `copy_style="none"`.
**Solutions:**
1. **Use explicit version bump** (mainly for container-based deployments):
```python
run = flyte.with_runcontext(version="v2").run(my_task)
```
2. **Check if `copy_style="none"`** is set - this requires image rebuild:
```python
# If using copy_style="none", rebuild image
run = flyte.with_runcontext(
copy_style="none",
version="v2" # Bump version to force rebuild
).run(my_task)
```
### Files missing in container
**Problem:** Task can't find data files or configs
**Solutions:**
1. **Use `copy_style="all"`** to bundle all files:
```bash
flyte run --copy-style=all app.py my_task
```
2. **Copy files explicitly in image**:
```python
image = flyte.Image.from_debian_base().with_source_file(
src=pathlib.Path("config.yaml"),
dst="/app/config.yaml"
)
```
3. **Store data in remote storage** instead of bundling:
```python
@flyte.task
def my_task():
# Read from S3/GCS instead of local files
import flyte.io
data = flyte.io.File("s3://bucket/data.csv").open().read()
```
### Container build failures
**Problem:** Image build fails with `copy_style="none"`
**Solutions:**
1. **Check `root_dir` matches `copy_contents_only`**:
```python
# copy_contents_only=True
image = Image.from_debian_base().with_source_folder(
src=Path(__file__).parent,
copy_contents_only=True
)
flyte.init(root_dir=Path(__file__).parent) # Match!
```
2. **Ensure `flyte` executable available**:
```python
image = Image.from_debian_base() # Has flyte pre-installed
```
3. **Check file permissions** in source directory:
```bash
chmod -R +r project/
```
### Version conflicts
**Problem:** Multiple versions of same image causing confusion
**Solutions:**
1. **Use explicit versions**:
```python
run = flyte.with_runcontext(
copy_style="none",
version="v1.2.3" # Explicit, not auto-generated
).run(my_task)
```
2. **Clean old images**:
```bash
docker image prune -a
```
3. **Use semantic versioning** for clarity:
```python
version = "v1.0.0" # Major.Minor.Patch
```
---
## Further reading
- **Run and deploy tasks > Code packaging for remote execution > Image API Reference** - Complete Image class documentation
- **Run and deploy tasks > Code packaging for remote execution > TaskEnvironment** - Environment configuration options
- [Configuration Guide](./configuration/) - Setting up Flyte config files
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/task-deployment/deployment-patterns ===
# Deployment patterns
Once you understand the basics of task deployment, you can leverage various deployment patterns to handle different project structures, dependency management approaches, and deployment requirements. This guide covers the most common patterns with practical examples.
## Overview of deployment patterns
Flyte supports multiple deployment patterns to accommodate different project structures and requirements:
1. **Run and deploy tasks > Deployment patterns > Simple file deployment** - Single file with tasks and environments
2. **Run and deploy tasks > Deployment patterns > Custom Dockerfile deployment** - Full control over container environment
3. **Run and deploy tasks > Deployment patterns > PyProject package deployment** - Structured Python packages with dependencies and async tasks
4. **Run and deploy tasks > Deployment patterns > Package structure deployment** - Organized packages with shared environments
5. **Run and deploy tasks > Deployment patterns > Full build deployment** - Complete code embedding in containers
6. **Run and deploy tasks > Deployment patterns > Python path deployment** - Multi-directory project structures
7. **Run and deploy tasks > Deployment patterns > Dynamic environment deployment** - Environment selection based on domain context
Each pattern serves specific use cases and can be combined as needed for complex projects.
## Simple file deployment
The simplest deployment pattern involves defining both your tasks and task environment in a single Python file. This pattern works well for:
- Prototyping and experimentation
- Simple tasks with minimal dependencies
- Educational examples and tutorials
### Example structure
```python
import flyte
env = flyte.TaskEnvironment(name="simple_env")
@env.task
async def my_task(name: str) -> str:
return f"Hello, {name}!"
if __name__ == "__main__":
flyte.init_from_config()
flyte.deploy(env)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/simple_file.py)
### Deployment commands
```bash
# Deploy the environment
flyte deploy my_example.py env
# Run the task ephemerally
flyte run my_example.py my_task --name "World"
```
### When to use
- Quick prototypes and experiments
- Single-purpose scripts
- Learning Flyte basics
- Tasks with no external dependencies
## Custom Dockerfile deployment
When you need full control over the container environment, you can specify a custom Dockerfile. This pattern is ideal for:
- Complex system dependencies
- Specific OS or runtime requirements
- Custom base images
- Multi-stage builds
### Example structure
```dockerfile
# syntax=docker/dockerfile:1.5
FROM ghcr.io/astral-sh/uv:0.8 as uv
FROM python:3.12-slim-bookworm
USER root
# Copy in uv so that later commands don't have to mount it in
COPY --from=uv /uv /usr/bin/uv
# Configure default envs
ENV UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
VIRTUALENV=/opt/venv \
UV_PYTHON=/opt/venv/bin/python \
PATH="/opt/venv/bin:$PATH"
# Create a virtualenv with the user specified python version
RUN uv venv /opt/venv --python=3.12
WORKDIR /root
# Install dependencies
COPY requirements.txt .
RUN uv pip install --pre -r /root/requirements.txt
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/Dockerfile)
```python
from pathlib import Path
import flyte
env = flyte.TaskEnvironment(
name="docker_env",
image=flyte.Image.from_dockerfile(
# relative paths in python change based on where you call, so set it relative to this file
Path(__file__).parent / "Dockerfile",
registry="ghcr.io/flyteorg",
name="docker_env_image",
),
)
@env.task
def main(x: int) -> int:
return x * 2
if __name__ == "__main__":
import flyte.git
flyte.init_from_config(flyte.git.config_from_root())
run = flyte.run(main, x=10)
print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/dockerfile_env.py)
### Alternative: Dockerfile in different directory
You can also reference Dockerfiles from subdirectories:
```python
from pathlib import Path
import flyte
env = flyte.TaskEnvironment(
name="docker_env_in_dir",
image=flyte.Image.from_dockerfile(
# relative paths in python change based on where you call, so set it relative to this file
Path(__file__).parent.parent / "Dockerfile.workdir",
registry="ghcr.io/flyteorg",
name="docker_env_image",
),
)
@env.task
def main(x: int) -> int:
return x * 2
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main, x=10)
print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/src/docker_env_in_dir.py)
```dockerfile
# syntax=docker/dockerfile:1.5
FROM ghcr.io/astral-sh/uv:0.8 as uv
FROM python:3.12-slim-bookworm
USER root
# Copy in uv so that later commands don't have to mount it in
COPY --from=uv /uv /usr/bin/uv
# Configure default envs
ENV UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
VIRTUALENV=/opt/venv \
UV_PYTHON=/opt/venv/bin/python \
PATH="/opt/venv/bin:$PATH"
# Create a virtualenv with the user specified python version
RUN uv venv /opt/venv --python=3.12
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN uv pip install --pre -r /app/requirements.txt
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dockerfile/Dockerfile.workdir)
### Key considerations
- **Path handling**: Use `Path(__file__).parent` for relative Dockerfile paths
```python
# relative paths in python change based on where you call, so set it relative to this file
Path(__file__).parent / "Dockerfile"
```
- **Registry configuration**: Specify a registry for image storage
- **Build context**: The directory containing the Dockerfile becomes the build context
- **Flyte installation**: Ensure Flyte is installed in the container and available on `$PATH`
```dockerfile
# Install Flyte in your Dockerfile
RUN pip install flyte
```
- **Dependencies**: Include all application requirements in the Dockerfile or requirements.txt
### When to use
- Need specific system packages or tools
- Custom base image requirements
- Complex installation procedures
- Multi-stage build optimization
## PyProject package deployment
For structured Python projects with proper package management, use the PyProject pattern. This approach demonstrates a **realistic Python project structure** that provides:
- Proper dependency management with `pyproject.toml` and external packages like `httpx`
- Clean separation of business logic and Flyte tasks across multiple modules
- Professional project structure with `src/` layout
- Async task execution with API calls and data processing
- Entrypoint patterns for both command-line and programmatic execution
### Example structure
```
pyproject_package/
├── pyproject.toml              # Project metadata and dependencies
├── README.md                   # Documentation
└── src/
    └── pyproject_package/
        ├── __init__.py         # Package initialization
        ├── main.py             # Entrypoint script
        ├── data/
        │   ├── __init__.py
        │   ├── loader.py       # Data loading utilities (no Flyte)
        │   └── processor.py    # Data processing utilities (no Flyte)
        ├── models/
        │   ├── __init__.py
        │   └── analyzer.py     # Analysis utilities (no Flyte)
        └── tasks/
            ├── __init__.py
            └── tasks.py        # Flyte task definitions
```
### Business logic modules
The business logic is completely separate from Flyte and can be used independently:
#### Data Loading (`data/loader.py`)
```python
import json
from pathlib import Path
from typing import Any
import httpx
async def fetch_data_from_api(url: str) -> list[dict[str, Any]]:
async with httpx.AsyncClient() as client:
response = await client.get(url, timeout=10.0)
response.raise_for_status()
return response.json()
def load_local_data(file_path: str | Path) -> dict[str, Any]:
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
with path.open("r") as f:
return json.load(f)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/data/loader.py)
#### Data Processing (`data/processor.py`)
```python
import asyncio
from typing import Any
from pydantic import BaseModel, Field, field_validator
class DataItem(BaseModel):
id: int = Field(gt=0, description="Item ID must be positive")
value: float = Field(description="Item value")
category: str = Field(min_length=1, description="Item category")
@field_validator("category")
@classmethod
def category_must_be_lowercase(cls, v: str) -> str:
return v.lower()
def clean_data(raw_data: dict[str, Any]) -> dict[str, Any]:
# Remove None values
cleaned = {k: v for k, v in raw_data.items() if v is not None}
# Validate items if present
if "items" in cleaned:
validated_items = []
for item in cleaned["items"]:
try:
validated = DataItem(**item)
validated_items.append(validated.model_dump())
except Exception as e:
print(f"Skipping invalid item {item}: {e}")
continue
cleaned["items"] = validated_items
return cleaned
def transform_data(data: dict[str, Any]) -> list[dict[str, Any]]:
items = data.get("items", [])
# Add computed fields
transformed = []
for item in items:
transformed_item = {
**item,
"value_squared": item["value"] ** 2,
"category_upper": item["category"].upper(),
}
transformed.append(transformed_item)
return transformed
async def aggregate_data(items: list[dict[str, Any]]) -> dict[str, Any]:
# Simulate async processing
await asyncio.sleep(0.1)
aggregated: dict[str, dict[str, Any]] = {}
for item in items:
category = item["category"]
if category not in aggregated:
aggregated[category] = {
"count": 0,
"total_value": 0.0,
"values": [],
}
aggregated[category]["count"] += 1
aggregated[category]["total_value"] += item["value"]
aggregated[category]["values"].append(item["value"])
# Calculate averages
for category, v in aggregated.items():
total = v["total_value"]
count = v["count"]
v["average_value"] = total / count if count > 0 else 0.0
return {"categories": aggregated, "total_items": len(items)}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/data/processor.py)
#### Analysis (`models/analyzer.py`)
```python
from typing import Any
import numpy as np
def calculate_statistics(data: list[dict[str, Any]]) -> dict[str, Any]:
if not data:
return {
"count": 0,
"mean": 0.0,
"median": 0.0,
"std_dev": 0.0,
"min": 0.0,
"max": 0.0,
}
values = np.array([item["value"] for item in data])
stats = {
"count": len(values),
"mean": float(np.mean(values)),
"median": float(np.median(values)),
"std_dev": float(np.std(values)),
"min": float(np.min(values)),
"max": float(np.max(values)),
"percentile_25": float(np.percentile(values, 25)),
"percentile_75": float(np.percentile(values, 75)),
}
return stats
def generate_report(stats: dict[str, Any]) -> str:
report_lines = [
"=" * 60,
"DATA ANALYSIS REPORT",
"=" * 60,
]
# Basic statistics section
if "basic" in stats:
basic = stats["basic"]
report_lines.extend(
[
"",
"BASIC STATISTICS:",
f" Count: {basic.get('count', 0)}",
f" Mean: {basic.get('mean', 0.0):.2f}",
f" Median: {basic.get('median', 0.0):.2f}",
f" Std Dev: {basic.get('std_dev', 0.0):.2f}",
f" Min: {basic.get('min', 0.0):.2f}",
f" Max: {basic.get('max', 0.0):.2f}",
f" 25th %ile: {basic.get('percentile_25', 0.0):.2f}",
f" 75th %ile: {basic.get('percentile_75', 0.0):.2f}",
]
)
# Category aggregations section
if "aggregated" in stats and "categories" in stats["aggregated"]:
categories = stats["aggregated"]["categories"]
total_items = stats["aggregated"].get("total_items", 0)
report_lines.extend(
[
"",
"CATEGORY BREAKDOWN:",
f" Total Items: {total_items}",
"",
]
)
for category, cat_stats in sorted(categories.items()):
report_lines.extend(
[
f" Category: {category.upper()}",
f" Count: {cat_stats.get('count', 0)}",
f" Total Value: {cat_stats.get('total_value', 0.0):.2f}",
f" Average Value: {cat_stats.get('average_value', 0.0):.2f}",
"",
]
)
report_lines.append("=" * 60)
return "\n".join(report_lines)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/models/analyzer.py)
These modules demonstrate:
- **No Flyte dependencies** - can be tested and used independently
- **Pydantic models** for data validation with custom validators
- **Async patterns** with proper context managers and error handling
- **NumPy integration** for statistical calculations
- **Professional error handling** with timeouts and validation
### Flyte orchestration layer
The Flyte tasks orchestrate the business logic with proper async execution:
```python
import pathlib
from typing import Any
import flyte
from pyproject_package.data import loader, processor
from pyproject_package.models import analyzer
UV_PROJECT_ROOT = pathlib.Path(__file__).parent.parent.parent.parent
env = flyte.TaskEnvironment(
name="data_pipeline",
image=flyte.Image.from_debian_base().with_uv_project(pyproject_file=UV_PROJECT_ROOT / "pyproject.toml"),
resources=flyte.Resources(memory="512Mi", cpu="500m"),
)
@env.task
async def fetch_task(url: str) -> list[dict[str, Any]]:
print(f"Fetching data from: {url}")
data = await loader.fetch_data_from_api(url)
print(f"Fetched {len(data)} top-level keys")
return data
@env.task
async def process_task(raw_data: dict[str, Any]) -> list[dict[str, Any]]:
print("Cleaning data...")
cleaned = processor.clean_data(raw_data)
print("Transforming data...")
transformed = processor.transform_data(cleaned)
print(f"Processed {len(transformed)} items")
return transformed
@env.task
async def analyze_task(processed_data: list[dict[str, Any]]) -> str:
print("Aggregating data...")
aggregated = await processor.aggregate_data(processed_data)
print("Calculating statistics...")
stats = analyzer.calculate_statistics(processed_data)
print("Generating report...")
report = analyzer.generate_report({"basic": stats, "aggregated": aggregated})
print("\n" + report)
return report
@env.task
async def pipeline(api_url: str) -> str:
# Chain tasks together
raw_data = await fetch_task(url=api_url)
processed_data = await process_task(raw_data=raw_data[0])
report = await analyze_task(processed_data=processed_data)
return report
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/tasks/tasks.py)
### Entrypoint configuration
The main entrypoint demonstrates proper initialization and execution patterns:
```python
import pathlib
import flyte
from pyproject_package.tasks.tasks import pipeline
def main():
# Initialize Flyte connection
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
    # Example API endpoint serving placeholder data
    example_url = "https://jsonplaceholder.typicode.com/posts"
    print("Starting data pipeline...")
    print(f"Target API: {example_url}")
    # Submit the pipeline for remote execution
run = flyte.run(pipeline, api_url=example_url)
print(f"\nRun Name: {run.name}")
print(f"Run URL: {run.url}")
run.wait()
if __name__ == "__main__":
main()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/src/pyproject_package/main.py)
### Dependencies and configuration
```toml
[project]
name = "pyproject-package"
version = "0.1.0"
description = "Example Python package with Flyte tasks and modular business logic"
readme = "README.md"
authors = [
{ name = "Ketan Umare", email = "kumare3@users.noreply.github.com" }
]
requires-python = ">=3.10"
dependencies = [
"flyte>=2.0.0b24",
"httpx>=0.27.0",
"numpy>=1.26.0",
"pydantic>=2.0.0",
]
[project.scripts]
run-pipeline = "pyproject_package.main:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pyproject_package/pyproject.toml)
### Key features
- **Async task chains**: Tasks can be chained together with proper async/await patterns
- **External dependencies**: Demonstrates integration with external libraries (`httpx`, `numpy`, `pydantic`)
- **uv integration**: Uses `.with_uv_project()` for dependency management
- **Resource specification**: Shows how to set memory and CPU requirements
- **Proper error handling**: Includes timeout and error handling in API calls
### Key learning points
1. **Separation of concerns**: Business logic (`data/`, `models/`) separate from orchestration (`main.py`)
2. **Reusable code**: Non-Flyte modules can be tested independently and reused
3. **Async support**: Demonstrates async Flyte tasks for I/O-bound operations
4. **Dependency management**: Shows how external packages integrate with Flyte
5. **Realistic structure**: Mirrors real-world Python project organization
6. **Entrypoint script**: Shows how to create runnable entry points
### Usage patterns
**Run the pipeline (the entrypoint submits a remote run via `flyte.run`):**
```bash
python -m pyproject_package.main
```
**Or use the console script defined in `project.scripts`:**
```bash
run-pipeline
```
**Deploy to Flyte:**
```bash
flyte deploy .
```
### What this example demonstrates
- Multiple files and modules in a package
- Async Flyte tasks with external API calls
- Separation of business logic from orchestration
- External dependencies (`httpx`, `numpy`, `pydantic`)
- **Data validation with Pydantic models** for robust data processing
- **Professional error handling** with try/except for data validation
- **Timeout configuration** for external API calls (`timeout=10.0`)
- **Async context managers** for proper resource management (`async with httpx.AsyncClient()`)
- Entrypoint script pattern with `project.scripts`
- Realistic project structure with `src/` layout
- Task chaining and data flow
- How non-Flyte code integrates with Flyte tasks
### When to use
- Production-ready, maintainable projects
- Projects requiring external API integration
- Complex data processing pipelines
- Team development with proper separation of concerns
- Applications needing async execution patterns
## Package structure deployment
For organizing Flyte workflows in a package structure with shared task environments and utilities, use this pattern. It's particularly useful for:
- Multiple workflows that share common environments and utilities
- Organized code structure with clear module boundaries
- Projects where you want to reuse task environments across workflows
### Example structure
```
lib/
├── __init__.py
└── workflows/
    ├── __init__.py
    ├── workflow1.py    # First workflow
    ├── workflow2.py    # Second workflow
    ├── env.py          # Shared task environment
    └── utils.py        # Shared utilities
```
### Key concepts
- **Shared environments**: Define task environments in `env.py` and import them across workflows (see the sketch below)
- **Utility modules**: Common functions and utilities shared between workflows
- **Root directory handling**: Use `--root-dir` flag for proper Python path configuration
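A minimal sketch of this layout, assuming the `lib.workflows` package shown above; the task body is illustrative:
```python
# lib/workflows/env.py -- the shared task environment
import flyte

env = flyte.TaskEnvironment(name="shared_env")
```
```python
# lib/workflows/workflow1.py -- a workflow that reuses the shared environment
from lib.workflows.env import env

@env.task
async def process_workflow(x: int) -> int:
    return x * 2
```
Because both workflows import the same `env`, they share a single task environment instead of duplicating configuration.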
### Running with root directory
When running workflows with a package structure, specify the root directory:
```bash
# Run first workflow
flyte run --root-dir . lib/workflows/workflow1.py process_workflow
# Run second workflow
flyte run --root-dir . lib/workflows/workflow2.py math_workflow --n 6
```
### How `--root-dir` works
The `--root-dir` flag automatically configures the Python path (`sys.path`) to ensure:
1. **Local execution**: Package imports work correctly when running locally
2. **Consistent behavior**: Same Python path configuration locally and at runtime
3. **No manual PYTHONPATH**: Eliminates need to manually export environment variables
4. **Runtime packaging**: Flyte packages and copies code correctly to execution environment
5. **Runtime consistency**: The same package structure is preserved in the runtime container
### Alternative: Using a Python project
For larger projects, create a proper Python project with `pyproject.toml`:
```toml
# pyproject.toml
[project]
name = "lib"
version = "0.1.0"
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
```
Then install in editable mode:
```bash
pip install -e .
```
After installation, you can run workflows without `--root-dir`:
```bash
flyte run lib/workflows/workflow1.py process_workflow
```
However, for deployment and remote execution, still use `--root-dir` for consistency:
```bash
flyte run --root-dir . lib/workflows/workflow1.py process_workflow
flyte deploy --root-dir . lib/workflows/workflow1.py
```
### When to use
- Multiple related workflows in one project
- Shared task environments and utilities
- Team projects with multiple contributors
- Applications requiring organized code structure
- Projects that benefit from proper Python packaging
## Full build deployment
When you need complete reproducibility and want to embed all code directly in the container image, use the full build pattern. This disables Flyte's fast deployment system in favor of traditional container builds.
### Overview
By default, Flyte uses a fast deployment system that:
- Creates a tar archive of your files
- Skips the full image build and push process
- Provides faster iteration during development
However, sometimes you need to **completely embed your code into the container image** for:
- Full reproducibility with immutable container images
- Environments where fast deployment isn't available
- Production deployments with all dependencies baked in
- Air-gapped or restricted deployment environments
### Key configuration
```python
import pathlib
from dep import foo
import flyte
env = flyte.TaskEnvironment(
name="full_build",
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=True # Avoid nested folders
),
)
@env.task
def square(x: int) -> int:
return x ** foo()
@env.task
def main(n: int) -> list[int]:
return list(flyte.map(square, range(n)))
if __name__ == "__main__":
# copy_contents_only=True requires root_dir=parent, False requires root_dir=parent.parent
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
run = flyte.with_runcontext(copy_style="none", version="x").run(main, n=10)
print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/full_build/main.py)
### Local dependency example
The main.py file imports from a local dependency that gets included in the build:
```python
def foo() -> int:
return 1
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/full_build/dep.py)
### Critical configuration components
1. **Set `copy_style` to `"none"`**:
```python
flyte.with_runcontext(copy_style="none", version="x").run(main, n=10)
```
This disables Flyte's fast deployment system and forces a full container build.
2. **Set a custom version**:
```python
flyte.with_runcontext(copy_style="none", version="x").run(main, n=10)
```
The `version` parameter should be set to a desired value (not auto-generated) for consistent image tagging.
3. **Configure image source copying**:
```python
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=True
)
```
Use `.with_source_folder()` to specify what code to copy into the container.
4. **Set `root_dir` correctly**:
```python
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
```
- If `copy_contents_only=True`: Set `root_dir` to the source folder (contents are copied)
- If `copy_contents_only=False`: Set `root_dir` to parent directory (folder is copied)
### Configuration options
#### Option A: Copy Folder Structure
```python
# Copies the entire folder structure into the container
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=False # Default
)
# When copy_contents_only=False, set root_dir to parent.parent
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
#### Option B: Copy Contents Only (Recommended)
```python
# Copies only the contents of the folder (flattens structure)
# This is useful when you want to avoid nested folders - for example all your code is in the root of the repo
image=flyte.Image.from_debian_base().with_source_folder(
pathlib.Path(__file__).parent,
copy_contents_only=True
)
# When copy_contents_only=True, set root_dir to parent
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
```
### Version management best practices
When using `copy_style="none"`, always specify an explicit version:
- Use semantic versioning: `"v1.0.0"`, `"v1.1.0"`
- Use build numbers: `"build-123"`
- Use git commits: `"abc123"`
Avoid auto-generated versions to ensure reproducible deployments.
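For example, pinning a semantic version (the value itself is illustrative):
```python
# Pin an explicit version so repeated full builds are tagged consistently
run = flyte.with_runcontext(copy_style="none", version="v1.0.0").run(main, n=10)
```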
### Performance considerations
- **Full builds take longer** than fast deployment
- **Container images will be larger** as they include all source code
- **Better for production** where immutability is important
- **Use during development** when testing the full deployment pipeline
### When to use
✅ **Use full build when:**
- Deploying to production environments
- Need immutable, reproducible container images
- Working with complex dependency structures
- Deploying to air-gapped or restricted environments
- Building CI/CD pipelines
❌ **Don't use full build when:**
- Rapid development and iteration
- Working with frequently changing code
- Development environments where speed matters
- Simple workflows without complex dependencies
### Troubleshooting
**Common issues:**
1. **Import errors**: Check your `root_dir` configuration matches `copy_contents_only`
2. **Missing files**: Ensure all dependencies are in the source folder
3. **Version conflicts**: Use explicit, unique version strings
4. **Build failures**: Check that the base image has all required system dependencies
**Debug tips:**
- Add print statements to verify file paths in containers
- Use `docker run -it <image> /bin/bash` to inspect built images
- Check Flyte logs for build errors and warnings
- Verify that relative imports work correctly in the container context
## Python path deployment
For projects where workflows are separated from business logic across multiple directories, use the Python path pattern with proper `root_dir` configuration.
### Example structure
```
pythonpath/
├── workflows/
│   └── workflow.py             # Flyte workflow definitions
├── src/
│   └── my_module.py            # Business logic modules
├── run.sh                      # Execute from project root
└── run_inside_folder.sh        # Execute from workflows/ directory
```
### Implementation
```python
import pathlib
from src.my_module import env, say_hello
import flyte
env = flyte.TaskEnvironment(
name="workflow_env",
depends_on=[env],
)
@env.task
async def greet(name: str) -> str:
return await say_hello(name)
if __name__ == "__main__":
current_dir = pathlib.Path(__file__).parent
flyte.init_from_config(root_dir=current_dir.parent)
r = flyte.run(greet, name="World")
print(r.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pythonpath/workflows/workflow.py)
```python
import flyte
env = flyte.TaskEnvironment(
name="my_module",
)
@env.task
async def say_hello(name: str) -> str:
return f"Hello, {name}!"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/pythonpath/src/my_module.py)
### Task environment dependencies
Note how the workflow imports both the task environment and the task function:
```python
from src.my_module import env, say_hello
env = flyte.TaskEnvironment(
name="workflow_env",
depends_on=[env], # Depends on the imported environment
)
```
This pattern allows sharing task environments across modules while maintaining proper dependency relationships.
### Key considerations
- **Import resolution**: `root_dir` enables proper module imports across directories
- **File packaging**: Flyte packages all files starting from `root_dir`
- **Execution flexibility**: Works regardless of where you execute the script
- **PYTHONPATH handling**: Different behavior for CLI vs direct Python execution
### CLI vs Direct Python execution
#### Using Flyte CLI with `--root-dir` (Recommended)
When using `flyte run` with `--root-dir`, you don't need to export PYTHONPATH:
```bash
flyte run --root-dir . workflows/workflow.py greet --name "World"
```
The CLI automatically:
- Adds the `--root-dir` location to `sys.path`
- Resolves all imports correctly
- Packages files from the root directory for remote execution
#### Using Python directly
When running Python scripts directly, you must set PYTHONPATH manually:
```bash
PYTHONPATH=.:$PYTHONPATH python workflows/workflow.py
```
This is because:
- Python doesn't automatically know about your project structure
- You need to explicitly tell Python where to find your modules
- The `root_dir` parameter handles remote packaging, not local path resolution
### Best practices
1. **Always set `root_dir`** when workflows import from multiple directories
2. **Use pathlib** for cross-platform path handling
3. **Set `root_dir` to your project root** to ensure all dependencies are captured
4. **Test both execution patterns** to ensure deployment works from any directory
### Common pitfalls
- **Forgetting `root_dir`**: Results in import errors during remote execution
- **Wrong `root_dir` path**: May package too many or too few files
- **Not setting PYTHONPATH when using Python directly**: Use `flyte run --root-dir .` instead
- **Mixing execution methods**: If you use `flyte run --root-dir .`, you don't need PYTHONPATH
### When to use
- Legacy projects with established directory structures
- Separation of concerns between workflows and business logic
- Multiple workflow definitions sharing common modules
- Projects with complex import hierarchies
**Note:** This pattern is an escape hatch for larger projects where code organization requires separating workflows from business logic. Ideally, structure projects with `pyproject.toml` for cleaner dependency management.
## Dynamic environment deployment
For environments that need to change based on deployment context (development vs production), use dynamic environment selection based on Flyte domains.
### Domain-based environment selection
Use `flyte.current_domain()` to deterministically create different task environments based on the deployment domain:
```python
# NOTE: flyte.init() invocation at the module level is strictly discouraged.
# At runtime, Flyte controls initialization and configuration files are not present.
import os
import flyte
def create_env():
if flyte.current_domain() == "development":
return flyte.TaskEnvironment(name="dev", image=flyte.Image.from_debian_base(), env_vars={"MY_ENV": "dev"})
return flyte.TaskEnvironment(name="prod", image=flyte.Image.from_debian_base(), env_vars={"MY_ENV": "prod"})
env = create_env()
@env.task
async def my_task(n: int) -> int:
print(f"Environment Variable MY_ENV = {os.environ['MY_ENV']}", flush=True)
return n + 1
@env.task
async def entrypoint(n: int) -> int:
print(f"Environment Variable MY_ENV = {os.environ['MY_ENV']}", flush=True)
return await my_task(n)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dynamic_environments/environment_picker.py)
### Why this pattern works
**Environment reproducibility in local and remote clusters is critical.** Flyte re-instantiates modules in remote clusters, so `current_domain()` will be set correctly based on where the code executes.
✅ **Do use `flyte.current_domain()`** - Flyte automatically sets this based on the execution context
❌ **Don't use environment variables directly** - They won't yield correct results unless manually passed to the downstream system
### How it works
1. Flyte sets the domain context when initializing
2. `current_domain()` returns the domain string (e.g., "development", "staging", "production")
3. Your code deterministically configures resources based on this domain
4. When Flyte executes remotely, it re-instantiates modules with the correct domain context
5. The same environment configuration logic runs consistently everywhere
### Important constraints
`flyte.current_domain()` only works **after** `flyte.init()` is called:
- ✅ Works with `flyte run` and `flyte deploy` CLI commands (they init automatically)
- ✅ Works when called from `if __name__ == "__main__"` after explicit `flyte.init()`
- ❌ Does NOT work at module level without initialization
**Critical:** `flyte.init()` invocation at the module level is **strictly discouraged**: at runtime, Flyte controls initialization, and configuration files are not present.
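A minimal sketch of the correct ordering for programmatic use, assuming a local config file exists:
```python
import flyte

if __name__ == "__main__":
    flyte.init_from_config()         # initialize first
    print(flyte.current_domain())    # only valid after init
```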
### Alternative: Environment variable approach
For cases where you need to pass domain information as environment variables to the container runtime, use this approach:
```python
import os
import flyte
def create_env(domain: str):
# Pass domain as environment variable so tasks can see which domain they're running in
if domain == "development":
return flyte.TaskEnvironment(name="dev", image=flyte.Image.from_debian_base(), env_vars={"DOMAIN_NAME": domain})
return flyte.TaskEnvironment(name="prod", image=flyte.Image.from_debian_base(), env_vars={"DOMAIN_NAME": domain})
env = create_env(os.getenv("DOMAIN_NAME", "development"))
@env.task
async def my_task(n: int) -> int:
print(f"Environment Variable MY_ENV = {os.environ['DOMAIN_NAME']}", flush=True)
return n + 1
@env.task
async def entrypoint(n: int) -> int:
print(f"Environment Variable MY_ENV = {os.environ['DOMAIN_NAME']}", flush=True)
return await my_task(n)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(entrypoint, n=5)
print(r.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dynamic_environments_with_envvars/environment_picker.py)
#### Key differences from domain-based approach
- **Environment variable access**: The domain name is available inside tasks via `os.environ['DOMAIN_NAME']`
- **External control**: Can be controlled via system environment variables before execution
- **Runtime visibility**: Tasks can inspect which environment they're running in during execution
- **Default fallback**: Uses `"development"` as default when `DOMAIN_NAME` is not set
#### Usage with environment variables
```bash
# Set environment and run
export DOMAIN_NAME=production
flyte run environment_picker.py entrypoint --n 5
# Or set inline
DOMAIN_NAME=development flyte run environment_picker.py entrypoint --n 5
```
#### When to use environment variables vs domain-based
**Use environment variables when:**
- Tasks need runtime access to environment information
- External systems set environment configuration
- You need flexibility to override environment externally
- Debugging requires visibility into environment selection
**Use domain-based approach when:**
- Environment selection should be automatic based on Flyte domain
- You want tighter integration with Flyte's domain system
- No need for runtime environment inspection within tasks
You can vary multiple aspects based on context (see the sketch after this list):
- **Base images**: Different images for dev vs prod
- **Environment variables**: Configuration per environment
- **Resource requirements**: Different CPU/memory per domain
- **Dependencies**: Different package versions
- **Registry settings**: Different container registries
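A sketch that varies image, environment variables, and resources by domain; the resource values are illustrative:
```python
import flyte

def create_env():
    # Development gets a small footprint; production gets more headroom.
    if flyte.current_domain() == "development":
        return flyte.TaskEnvironment(
            name="dev",
            image=flyte.Image.from_debian_base(),
            env_vars={"MY_ENV": "dev"},
            resources=flyte.Resources(cpu="500m", memory="512Mi"),
        )
    return flyte.TaskEnvironment(
        name="prod",
        image=flyte.Image.from_debian_base(),
        env_vars={"MY_ENV": "prod"},
        resources=flyte.Resources(cpu="2", memory="4Gi"),  # illustrative values
    )

env = create_env()
```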
### Usage patterns
```bash
# CLI usage (recommended)
flyte run environment_picker.py entrypoint --n 5
flyte deploy environment_picker.py
```
For programmatic usage, ensure proper initialization:
```python
import flyte
flyte.init_from_config()
from environment_picker import entrypoint
if __name__ == "__main__":
r = flyte.run(entrypoint, n=5)
print(r.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-deployment/deployment-patterns/dynamic_environments/main.py)
### When to use dynamic environments
**General use cases:**
- Multi-environment deployments (dev/staging/prod)
- Different resource requirements per environment
- Environment-specific dependencies or settings
- Context-sensitive configuration needs
**Domain-based approach for:**
- Automatic environment selection tied to Flyte domains
- Simpler configuration without external environment variables
- Integration with Flyte's built-in domain system
**Environment variable approach for:**
- Runtime visibility into environment selection within tasks
- External control over environment configuration
- Debugging and logging environment-specific behavior
- Integration with external deployment systems that set environment variables
## Best practices
### Project organization
1. **Separate concerns**: Keep business logic separate from Flyte task definitions
2. **Use proper imports**: Structure projects for clean import patterns
3. **Version control**: Include all necessary files in version control
4. **Documentation**: Document deployment requirements and patterns
### Image management
1. **Registry configuration**: Use consistent registry settings across environments
2. **Image tagging**: Use meaningful tags for production deployments
3. **Base image selection**: Choose appropriate base images for your needs
4. **Dependency management**: Keep container images lightweight but complete
### Configuration management
1. **Root directory**: Set `root_dir` appropriately for your project structure
2. **Path handling**: Use `pathlib.Path` for cross-platform compatibility
3. **Environment variables**: Use environment-specific configurations
4. **Secrets management**: Handle sensitive data appropriately
### Development workflow
1. **Local testing**: Test tasks locally before deployment
2. **Incremental development**: Use `flyte run` for quick iterations
3. **Production deployment**: Use `flyte deploy` for permanent deployments
4. **Monitoring**: Monitor deployed tasks and environments
## Choosing the right pattern
| Pattern | Use Case | Complexity | Best For |
|---------|----------|------------|----------|
| Simple file | Quick prototypes, learning | Low | Single tasks, experiments |
| Custom Dockerfile | System dependencies, custom environments | Medium | Complex dependencies |
| PyProject package | Professional projects, async pipelines | Medium-High | Production applications |
| Package structure | Multiple workflows, shared utilities | Medium | Organized team projects |
| Full build | Production, reproducibility | High | Immutable deployments |
| Python path | Legacy structures, separated concerns | Medium | Existing codebases |
| Dynamic environment | Multi-environment, domain-aware deployments | Medium | Context-aware deployments |
Start with simpler patterns and evolve to more complex ones as your requirements grow. Many projects will combine multiple patterns as they scale and mature.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling ===
# Scale your runs
This guide helps you understand and optimize the performance of your Flyte workflows. Whether you're building latency-sensitive applications or high-throughput data pipelines, these docs will help you make the right architectural choices.
## Understanding Flyte execution
Before optimizing performance, it's important to understand how Flyte executes your workflows:
- **Scale your runs > Data flow**: Learn how data moves between tasks, including inline vs. reference data types, caching mechanisms, and storage configuration.
- **Scale your runs > Life of a run**: Understand what happens when you invoke `flyte.run()`, from code analysis and image building to task execution and state management.
## Performance optimization
Once you understand the fundamentals, dive into performance tuning:
- **Scale your runs > Scale your workflows**: A comprehensive guide to optimizing workflow performance, covering latency vs. throughput, task overhead analysis, batching strategies, reusable containers, and more.
## Key concepts for scaling
When scaling your workflows, keep these principles in mind:
1. **Task overhead matters**: The overhead of creating a task (uploading data, enqueuing, creating containers) should be much smaller than the task runtime.
2. **Batch for throughput**: For large-scale data processing, batch multiple items into single tasks to reduce overhead.
3. **Reusable containers**: Eliminate container startup overhead and enable concurrent execution with reusable containers.
4. **Traces for lightweight ops**: Use traces instead of tasks for lightweight operations that need checkpointing.
5. **Limit fanout**: Keep the total number of actions per run below 50k (target 10k-20k for best performance).
6. **Choose the right data types**: Use reference types (files, directories, DataFrames) for large data and inline types for small data.
For detailed guidance on each of these topics, see **Scale your runs > Scale your workflows**.
## Subpages
- **Scale your runs > Data flow**
- **Scale your runs > Life of a run**
- **Scale your runs > Scale your workflows**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling/data-flow ===
# Data flow
Understanding how data flows between tasks is critical for optimizing workflow performance in Flyte. Tasks take inputs and produce outputs, with data flowing seamlessly through your workflow using an efficient transport layer.
## Overview
Flyte tasks are run to completion. Each task takes inputs and produces exactly one output. Even if multiple instances run concurrently (such as in retries), only one output will be accepted. This deterministic data flow model provides several key benefits:
1. **Reduced boilerplate**: Automatic handling of files, DataFrames, directories, custom types, data classes, Pydantic models, and primitive types without manual serialization.
2. **Type safety**: Optional type annotations enable deeper type understanding, automatic UI form generation, and runtime type validation.
3. **Efficient transport**: Data is passed by reference (files, directories, DataFrames) or by value (primitives) based on type.
4. **Durable storage**: All data is stored durably and accessible through APIs and the UI.
5. **Caching support**: Efficient caching using shallow immutable references for referenced data.
## Data types and transport
Flyte handles different data types with different transport mechanisms:
### Passed by reference
These types are not copied but passed as references to storage locations:
- **Files**: `flyte.io.File`
- **Directories**: `flyte.io.Directory`
- **Dataframes**: `flyte.io.DataFrame`, `pd.DataFrame`, `pl.DataFrame`, etc.
Dataframes are automatically converted to Parquet format and read using Arrow for zero-copy reads. Use `flyte.io.DataFrame` for lazy materialization to any supported type like pandas or polars.
### Passed by value (inline I/O)
Primitive and structured types are serialized and passed inline:
| Type Category | Examples | Serialization |
|--------------|----------|---------------|
| **Primitives** | `int`, `float`, `str`, `bool`, `None` | MessagePack |
| **Time types** | `datetime.datetime`, `datetime.date`, `datetime.timedelta` | MessagePack |
| **Collections** | `list`, `dict`, `tuple` | MessagePack |
| **Data structures** | data classes, Pydantic `BaseModel` | MessagePack |
| **Enums** | `enum.Enum` subclasses | MessagePack |
| **Unions** | `Union[T1, T2]`, `Optional[T]` | MessagePack |
| **Protobuf** | `google.protobuf.Message` | Binary |
Flyte uses efficient MessagePack serialization for most types, providing compact binary representation with strong type safety.
> [!NOTE]
> If type annotations are not used, or if `typing.Any` or unrecognized types are used, data will be pickled. By default, pickled objects smaller than 10KB are passed inline, while larger pickled objects are automatically passed as a file. Pickling allows for progressive typing but should be used carefully.
## Task execution and data flow
### Input download
When a task starts:
1. **Inline inputs download**: The task downloads inline inputs from the configured Flyte object store.
2. **Size limits**: By default, inline inputs are limited to 10MB, but this can be adjusted using `flyte.TaskEnvironment`'s `max_inline_io` parameter (see the sketch after this list).
3. **Memory consideration**: Inline data is materialized in memory, so adjust your task resources accordingly.
4. **Reference materialization**: Reference data (files, directories) is passed using special types in `flyte.io`. Dataframes are automatically materialized if using `pd.DataFrame`. Use `flyte.io.DataFrame` to avoid automatic materialization.
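A minimal sketch of raising the limit; note that the exact value format accepted by `max_inline_io` is an assumption here:
```python
import flyte

env = flyte.TaskEnvironment(
    name="big_inline_io",
    max_inline_io=20 * 1024 * 1024,  # assumed: limit expressed as bytes (20 MB)
)
```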
### Output upload
When a task returns data:
1. **Inline data**: Uploaded to the Flyte object store configured at the organization, project, or domain level.
2. **Reference data**: Stored in the same metadata store by default, or configured using `flyte.with_runcontext(raw_data_storage=...)`.
3. **Separate prefixes**: Each task writes one output per retry attempt under a separate prefix, so attempts cannot corrupt each other's data.
## Task-to-task data flow
When a task invokes downstream tasks:
1. **Input recording**: The input to the downstream task is recorded to the object store.
2. **Reference upload**: All referenced objects are uploaded (if not already present).
3. **Task invocation**: The downstream task is invoked on the remote server.
4. **Parallel execution**: When multiple tasks are invoked in parallel using `flyte.map` or `asyncio`, inputs are written in parallel.
5. **Storage layer**: Data writing uses the `flyte.storage` layer, backed by the Rust-based `object-store` crate and optionally `fsspec` plugins.
6. **Output download**: Once the downstream task completes, inline outputs are downloaded and returned to the calling task.
## Caching and data hashing
Understanding how Flyte caches data is essential for performance optimization.
### Cache key computation
A cache hit occurs when the following components match:
- **Task name**: The fully-qualified task name
- **Computed input hash**: Hash of all inputs (excluding `ignored_inputs`)
- **Task interface hash**: Hash of input and output types
- **Task config hash**: Hash of task configuration
- **Cache version**: User-specified or automatically computed
### Inline data caching
All inline data is cached using a consistent hashing system. The cache key is derived from the data content.
### Reference data hashing
Reference data (files, directories) is hashed shallowly by default using the hash of the storage location. You can customize hashing:
- Use `flyte.io.File.new_remote()` or `flyte.io.File.from_existing_remote()` with custom hash functions or values.
- Provide explicit hash values for deep content hashing if needed.
### Cache control
Control caching behavior using `flyte.with_runcontext`:
- **Scope**: Set `cache_lookup_scope` to `"global"` or `"project/domain"`.
- **Disable cache**: Set `overwrite_cache=True` to force re-execution.
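A sketch combining both knobs (the task and input names are placeholders):
```python
run = flyte.with_runcontext(
    cache_lookup_scope="project/domain",  # or "global"
    overwrite_cache=True,                 # force re-execution, ignoring cache hits
).run(my_task, input_data=data)
```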
For more details on caching configuration, see **Configure tasks > Caching**.
## Traces and data flow
When using **Build tasks > Traces**, the data flow behavior is different:
1. **Full execution first**: The trace is fully executed before inputs and outputs are recorded.
2. **Checkpoint behavior**: Recording happens like a checkpoint at the end of trace execution.
3. **Streaming iterators**: The entire output is buffered and recorded after the stream completes. Buffering is pass-through, allowing caller functions to consume output while buffering.
4. **Chained traces**: All traces are recorded after the last one completes consumption.
5. **Same process with `asyncio`**: Traces run within the same Python process and support `asyncio` parallelism, so failures can be retried, effectively re-running the trace.
6. **Lightweight overhead**: Traces only have the overhead of data storage (no task orchestration overhead).
> [!NOTE]
> Traces are not a substitute for tasks if you need caching. Tasks provide full caching capabilities, while traces provide lightweight checkpointing with storage overhead. However, traces support concurrent execution using `asyncio` patterns within a single task.
## Object stores and latency considerations
By default, Flyte uses object stores like S3, GCS, Azure Storage, and R2 as metadata stores. These have high latency for smaller objects, so:
- **Minimum task duration**: Tasks should take at least a second to run to amortize storage overhead.
- **Future improvements**: High-performance metastores like Redis and PostgreSQL may be supported in the future. Contact the Union team if you're interested.
## Configuring data storage
### Organization and project level
Object stores are configured at the organization level or per project/domain. Documentation for this configuration is coming soon.
### Per-run configuration
Configure raw data storage on a per-run basis using `flyte.with_runcontext`:
```python
run = flyte.with_runcontext(
raw_data_storage="s3://my-bucket/custom-path"
).run(my_task, input_data=data)
```
This allows you to control where reference data (files, directories, DataFrames) is stored for specific runs.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling/life-of-a-run ===
# Life of a run
Understanding what happens when you invoke `flyte.run()` is crucial for optimizing workflow performance and debugging issues. This guide walks through each phase of task execution from submission to completion.
## Overview
When you execute `flyte.run()`, the system goes through several phases:
1. **Code analysis and preparation**: Discover environments and images
2. **Image building**: Build container images if changes are detected
3. **Code bundling**: Package your Python code
4. **Upload**: Transfer the code bundle to object storage
5. **Run creation**: Submit the run to the backend
6. **Task execution**: Execute the task in the data plane
7. **State management**: Track and persist execution state
## Phase 1: Code analysis and preparation
When `flyte.run()` is invoked:
1. **Environment discovery**: Flyte analyzes your code and finds all relevant `flyte.TaskEnvironment` instances by walking the `depends_on` hierarchy.
2. **Image identification**: Discovers unique `flyte.Image` instances used across all environments.
3. **Image building**: Starts the image building process. Images are only built if a change is detected.
> [!NOTE]
> If you invoke `flyte.run()` multiple times within the same Python process without changing code (such as in a notebook or script), the code bundling and image building steps are done only once. This can dramatically speed up iteration.
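For example, a sketch of iterating from one process, where the task `main` stands in for any deployed task:
```python
import flyte

flyte.init_from_config()
# Bundling and image builds happen on the first call only;
# subsequent calls in the same process reuse them.
for n in (5, 10, 20):
    run = flyte.run(main, n=n)
    print(run.url)
```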
## Phase 2: Image building
Container images provide the runtime environment for your tasks:
- **Change detection**: Images are only rebuilt if changes are detected in dependencies or configuration.
- **Caching**: Previously built images are reused when possible.
- **Parallel builds**: Multiple images can be built concurrently.
For more details on container images, see **Configure tasks > Container images**.
## Phase 3: Code bundling
After images are built, your project files are bundled:
### Default: `copy_style="loaded_modules"`
By default, all Python modules referenced by the invoked tasks through module-level import statements are automatically copied. This provides a good balance between completeness and efficiency.
### Alternative: `copy_style="none"`
Skip bundling by setting `copy_style="none"` in `flyte.with_runcontext()` and adding all code into `flyte.Image`:
```python
# Add code to image
image = flyte.Image.from_debian_base().with_source_folder("/path/to/code")
# Or use Dockerfile
image = flyte.Image.from_dockerfile("Dockerfile")
# Skip bundling
run = flyte.with_runcontext(copy_style="none").run(my_task, input_data=data)
```
For more details on code packaging, see **Run and deploy tasks > Code packaging for remote execution**.
## Phase 4: Upload code bundle
Once the code bundle is created:
1. **Negotiate signed URL**: The SDK requests a signed URL from the backend.
2. **Upload**: The code bundle is uploaded to the signed URL location in object storage.
3. **Reference stored**: The backend stores a reference to the uploaded bundle.
## Phase 5: Run creation and queuing
The `CreateRun` API is invoked:
1. **Copy inputs**: Input data is copied to the object store.
2. **Enqueue a run**: The run is queued into the Union Control Plane.
3. **Hand off to executor**: Union Control Plane hands the task to the Executor Service in your data plane.
4. **Create action**: The parent task action (called `a0`) is created.
## Phase 6: Task execution in data plane
### Container startup
1. **Container starts**: The task container starts in your data plane.
2. **Download code bundle**: The Flyte runtime downloads the code bundle from object storage.
3. **Inflate task**: The task is inflated from the code bundle.
4. **Download inputs**: Inline inputs are downloaded from the object store.
5. **Execute task**: The task is executed with context and inputs.
### Invoking downstream tasks
If the task invokes other tasks:
1. **Controller thread**: A controller thread starts to communicate with the backend Queue Service.
2. **Monitor status**: The controller monitors the status of downstream actions.
3. **Crash recovery**: If the task crashes, the action identifier is deterministic, allowing the task to resurrect its state from Union Control Plane.
4. **Replay**: The controller efficiently replays state (even at large scale) to find missing completions and resume monitoring.
### Execution flow diagram
```mermaid
%%{init: {'theme':'base', 'themeVariables': { 'primaryColor':'#1f2937', 'primaryTextColor':'#e5e7eb', 'primaryBorderColor':'#6b7280', 'lineColor':'#9ca3af', 'secondaryColor':'#374151', 'tertiaryColor':'#1f2937', 'actorBorder':'#6b7280', 'actorTextColor':'#e5e7eb', 'signalColor':'#9ca3af', 'signalTextColor':'#e5e7eb'}}}%%
sequenceDiagram
participant Client as SDK/Client
participant Control as Control Plane (Queue Service)
participant Data as Data Plane (Executor)
participant ObjStore as Object Store
participant Container as Task Container
Client->>Client: Analyze code & discover environments
Client->>Client: Build images (if changed)
Client->>Client: Bundle code
Client->>Control: Upload code bundle
Control->>Data: Store code bundle
Data->>ObjStore: Write code bundle
Client->>Control: CreateRun API with inputs
Control->>Data: Copy inputs
Data->>ObjStore: Write inputs
Control->>Data: Queue task (create action a0)
Data->>Container: Start container
Container->>Data: Request code bundle
Data->>ObjStore: Read code bundle
ObjStore-->>Data: Code bundle
Data-->>Container: Code bundle
Container->>Container: Inflate task
Container->>Data: Request inputs
Data->>ObjStore: Read inputs
ObjStore-->>Data: Inputs
Data-->>Container: Inputs
Container->>Container: Execute task
alt Invokes downstream tasks
Container->>Container: Start controller thread
Container->>Control: Submit downstream tasks
Control->>Data: Queue downstream actions
Container->>Control: Monitor downstream status
Control-->>Container: Status updates
end
Container->>Data: Upload outputs
Data->>ObjStore: Write outputs
Container->>Control: Complete
Control-->>Client: Run complete
```
## Action identifiers and crash recovery
Flyte uses deterministic action identifiers to enable robust crash recovery:
- **Consistent identifiers**: Action identifiers are consistently computed based on task and invocation context.
- **Re-run identical**: In any re-run, the action identifier is identical for the same invocation.
- **Multiple invocations**: Multiple invocations of the same task receive unique identifiers.
- **Efficient resurrection**: On crash, the `a0` action resurrects its state from Union Control Plane efficiently, even at large scale.
- **Replay and resume**: The controller replays execution until it finds missing completions and starts watching them.
## Downstream task execution
When downstream tasks are invoked:
1. **Action creation**: Downstream actions are created with unique identifiers.
2. **Queue assignment**: Actions are handed to an executor, which can be selected using a queue or from the general pool.
3. **Parallel execution**: Multiple downstream tasks can execute in parallel.
4. **Result aggregation**: Results are aggregated and returned to the parent task.
## Reusable containers
When using **Configure tasks > Reusable containers**, the execution model changes:
1. **Environment spin-up**: The container environment is first spun up with configured replicas.
2. **Task allocation**: Tasks are allocated to available replicas in the environment.
3. **Scaling**: If all replicas are busy, new replicas are spun up (up to the configured maximum), or tasks are backlogged in queues.
4. **Container reuse**: The same container handles multiple task executions, reducing startup overhead.
5. **Lifecycle management**: Containers are managed according to `ReusePolicy` settings (`idle_ttl`, `scaledown_ttl`, etc.).
### Reusable container execution flow
```mermaid
%%{init: {'theme':'base', 'themeVariables': { 'primaryColor':'#1f2937', 'primaryTextColor':'#e5e7eb', 'primaryBorderColor':'#6b7280', 'lineColor':'#9ca3af', 'secondaryColor':'#374151', 'tertiaryColor':'#1f2937', 'actorBorder':'#6b7280', 'actorTextColor':'#e5e7eb', 'signalColor':'#9ca3af', 'signalTextColor':'#e5e7eb'}}}%%
sequenceDiagram
participant Control as Queue Service
participant Executor as Executor Service
participant Pool as Container Pool
participant Replica as Container Replica
Control->>Executor: Submit task
alt Reusable containers enabled
Executor->>Pool: Request available replica
alt Replica available
Pool->>Replica: Allocate task
Replica->>Replica: Execute task
Replica->>Pool: Task complete (ready for next)
else No replica available
alt Can scale up
Executor->>Pool: Create new replica
Pool->>Replica: Spin up new container
Replica->>Replica: Execute task
Replica->>Pool: Task complete
else At max replicas
Executor->>Pool: Queue task
Pool-->>Executor: Wait for available replica
Pool->>Replica: Allocate when available
Replica->>Replica: Execute task
Replica->>Pool: Task complete
end
end
else No reusable containers
Executor->>Replica: Create new container
Replica->>Replica: Execute task
Replica->>Executor: Complete & terminate
end
Replica-->>Control: Return results
```
## State replication and visualization
### Queue Service to Run Service
1. **Reliable replication**: Queue Service reliably replicates execution state back to Run Service.
2. **Eventual consistency**: The Run Service may be slightly behind the actual execution state.
3. **Visualization**: Run Service paints the entire run onto the UI.
### UI limitations
- **Current limit**: The UI is currently limited to displaying 50k actions per run.
- **Future improvements**: This limit will be increased in future releases. Contact the Union team if you need higher limits.
## Optimization opportunities
Understanding the life of a run reveals several optimization opportunities:
1. **Reuse Python process**: Run `flyte.run()` multiple times in the same process to avoid re-bundling code.
2. **Skip bundling**: Use `copy_style="none"` and bake code into images for faster startup.
3. **Reusable containers**: Use reusable containers to eliminate container startup overhead.
4. **Parallel execution**: Invoke multiple downstream tasks concurrently using `flyte.map()` or `asyncio`.
5. **Efficient data flow**: Minimize data transfer by using reference types (files, directories) instead of inline data.
6. **Caching**: Enable task caching to avoid redundant computation.
For detailed performance tuning guidance, see [Scale your workflows](performance).
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/run-scaling/scale-your-workflows ===
# Scale your workflows
Performance optimization in Flyte involves understanding the interplay between task execution overhead, data transfer, and concurrency. This guide helps you identify bottlenecks and choose the right patterns for your workload.
## Understanding performance dimensions
Performance optimization focuses on two key dimensions:
### Latency
**Goal**: Minimize end-to-end execution time for individual workflows.
**Characteristics**:
- Fast individual actions (milliseconds to seconds)
- Total action count typically less than 1,000
- Critical for interactive applications and real-time processing
- Multi-step inference that reuses a model or data in memory (use reusable containers with an in-memory async cache such as `@alru_cache`)
**Recommended approach**:
- Use tasks for orchestration and parallelism
- Use **Build tasks > Traces** for fine-grained checkpointing
- Model parallelism with `asyncio`, using `asyncio.as_completed` or `asyncio.gather` to join concurrent work
- Leverage **Configure tasks > Reusable containers** with concurrency to eliminate startup overhead
### Throughput
**Goal**: Maximize the number of items processed per unit time.
**Characteristics**:
- Processing large datasets (millions of items)
- High total action count (10k to 50k actions)
- Batch processing, large-scale batch inference, and ETL workflows
**Recommended approach**:
- Batch workloads to reduce overhead
- Limit fanout to manage system load
- Use reusable containers with concurrency for maximum utilization
- Balance task granularity with overhead
## Task execution overhead
Understanding task overhead is critical for performance optimization. When you invoke a task, several operations occur:
| Operation | Symbol | Description |
|-----------|--------|-------------|
| **Upload data** | `u` | Time to upload input data to object store |
| **Download data** | `d` | Time to download input data from object store |
| **Enqueue task** | `e` | Time to enqueue task in Queue Service |
| **Create instance** | `t` | Time to create task container instance |
**Total overhead per task**: `2u + 2d + e + t`
This overhead includes:
- Uploading inputs from the parent task (`u`)
- Downloading inputs in the child task (`d`)
- Uploading outputs from the child task (`u`)
- Downloading outputs in the parent task (`d`)
- Enqueuing the task (`e`)
- Creating the container instance (`t`)
### The overhead principle
For efficient execution, task overhead should be much smaller than task runtime:
```
Total overhead (2u + 2d + e + t) << Task runtime
```
If task runtime is comparable to or less than overhead, consider:
1. **Batching**: Combine multiple work items into a single task
2. **Traces**: Use traces instead of tasks for lightweight operations
3. **Reusable containers**: Eliminate container creation overhead (`t`)
4. **Local execution**: Run lightweight operations within the parent task
## System architecture and data flow
To optimize performance, understand how tasks flow through the system:
1. **Control plane to data plane**: Tasks flow from the control plane (Run Service, Queue Service) to the data plane (Executor Service).
2. **Data movement**: Data moves between tasks through object storage. See **Scale your runs > Data flow** for details.
3. **State replication**: Queue Service reliably replicates state back to Run Service for visualization. The Run Service may be slightly behind actual execution.
For a detailed walkthrough of task execution, see **Scale your runs > Life of a run**.
## Optimization strategies
### 1. Use reusable containers for concurrency
**Configure tasks > Reusable containers** eliminate the container creation overhead (`t`) and enable concurrent task execution:
```python
import flyte
from datetime import timedelta
# Define reusable environment
env = flyte.TaskEnvironment(
name="high-throughput",
reuse_policy=flyte.ReusePolicy(
replicas=(2, 10), # Auto-scale from 2 to 10 replicas
concurrency=5, # 5 tasks per replica = 50 max concurrent
scaledown_ttl=timedelta(minutes=10),
idle_ttl=timedelta(hours=1)
)
)
@env.task
async def process_item(item: dict) -> dict:
# Process individual item
return {"processed": item["id"]}
```
**Benefits**:
- Eliminates container startup overhead (`t ≈ 0`)
- Supports concurrent execution (multiple tasks per container)
- Auto-scales based on demand
- Reuses Python environment and loaded dependencies
**Limitations**:
- Concurrency is limited by CPU and I/O resources in the container
- Memory requirements scale with total working set size
- Best for I/O-bound tasks or async operations
### 2. Batch workloads to reduce overhead
For high-throughput processing, batch multiple items into a single task:
```python
import asyncio

@env.task
async def process_batch(items: list[dict]) -> list[dict]:
"""Process a batch of items in a single task."""
results = []
for item in items:
result = await process_single_item(item)
results.append(result)
return results
@env.task
async def process_large_dataset(dataset: list[dict]) -> list[dict]:
"""Process 1M items with batching."""
batch_size = 1000 # Adjust based on overhead calculation
batches = [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]
# Process batches in parallel (1000 tasks instead of 1M)
results = await asyncio.gather(*[process_batch(batch) for batch in batches])
# Flatten results
return [item for batch_result in results for item in batch_result]
```
**Benefits**:
- Reduces total number of tasks (e.g., 1000 tasks instead of 1M)
- Amortizes overhead across multiple items
- Lower load on Queue Service and object storage
**Choosing batch size**:
1. Calculate overhead: `overhead = 2u + 2d + e + t`
2. Target task runtime: `runtime > 10 Γ overhead` (rule of thumb)
3. Adjust batch size to achieve target runtime
4. Consider memory constraints (larger batches require more memory)
### 3. Use traces for lightweight operations
**Build tasks > Traces** provide fine-grained checkpointing with minimal overhead:
```python
@flyte.trace
async def fetch_data(url: str) -> dict:
"""Traced function for API call."""
response = await http_client.get(url)
return response.json()
@flyte.trace
async def transform_data(data: dict) -> dict:
"""Traced function for transformation."""
return {"transformed": data}
@env.task
async def process_workflow(urls: list[str]) -> list[dict]:
"""Orchestrate using traces instead of tasks."""
results = []
for url in urls:
data = await fetch_data(url)
transformed = await transform_data(data)
results.append(transformed)
return results
```
**Benefits**:
- Only storage overhead (no task orchestration overhead)
- Runs in the same Python process with asyncio parallelism
- Provides checkpointing and resumption
- Visible in execution logs and UI
**Trade-offs**:
- No caching (use tasks for cacheable operations)
- Shares resources with the parent task (CPU, memory)
- Storage writes may still be slow due to object store latency
**When to use traces**:
- API calls and external service interactions
- Deterministic transformations that need checkpointing
- Operations taking more than 1 second (to amortize storage overhead)
### 4. Limit fanout for system stability
The UI and system have limits on the number of actions per run:
- **Current limit**: 50k actions per run
- **Future**: Higher limits will be supported (contact the Union team if needed)
**Example: Control fanout with batching**
```python
@env.task
async def process_million_items(items: list[dict]) -> list[dict]:
    """Process 1M items with controlled fanout."""
    # Target ~10k tasks: 1M items at 100 items per batch = 10k batches
    batch_size = 100
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    # Use flyte.map for parallel execution
    results = await flyte.map(process_batch, batches)
    return [item for batch in results for item in batch]
```
### 5. Optimize data transfer
Minimize data transfer overhead by choosing appropriate data types:
**Use reference types for large data**:
```python
from flyte.io import File, Directory, DataFrame
@env.task
async def process_large_file(input_file: File) -> File:
    """Files are passed by reference, not copied."""
    # Download the file contents only when needed
    local_path = await input_file.download()
    # Process the file (process is a placeholder for your own logic)
    result_path = process(local_path)
    # Upload the result and return a reference to it
    return await File.from_local(result_path)
```
**Use inline types for small data**:
```python
@env.task
async def process_metadata(metadata: dict) -> dict:
"""Small dicts passed inline efficiently."""
return {"processed": metadata}
```
**Guideline**:
- **< 10 MB**: Use inline types (primitives, small dicts, lists)
- **> 10 MB**: Use reference types (File, Directory, DataFrame)
- **Adjust**: Use `max_inline_io` in `TaskEnvironment` to change the threshold
See **Scale your runs > Data flow** for details on data types and transport.
### 6. Leverage caching
Enable **Configure tasks > Caching** to avoid redundant computation:
```python
@env.task(cache="auto")
async def expensive_computation(input_data: dict) -> dict:
    """Automatically cached based on inputs."""
    # run_expensive_operation is a placeholder for the actual work
    result = run_expensive_operation(input_data)
    return result
```
**Benefits**:
- Skips re-execution for identical inputs
- Reduces overall workflow runtime
- Preserves resources for new computations
**When to use**:
- Deterministic tasks (same inputs → same outputs)
- Expensive computations (model training, large data processing)
- Stable intermediate results
### 7. Parallelize with `flyte.map`
Use **Build tasks > Fanout** for data-parallel workloads:
```python
@env.task
async def process_item(item: dict) -> dict:
return {"processed": item}
@env.task
async def parallel_processing(items: list[dict]) -> list[dict]:
"""Process items in parallel using map."""
results = await flyte.map(process_item, items)
return results
```
**Benefits**:
- Automatic parallelization
- Dynamic scaling based on available resources
- Built-in error handling and retries
**Best practices**:
- Combine with batching to control fanout
- Use with reusable containers for maximum throughput
- Consider memory and resource limits
## Performance tuning workflow
Follow this workflow to optimize your Flyte workflows:
1. **Profile**: Measure task execution times and identify bottlenecks.
2. **Calculate overhead**: Estimate `2u + 2d + e + t` for your tasks.
3. **Compare**: Check if `task runtime >> overhead`. If not, optimize.
4. **Batch**: Increase batch size to amortize overhead.
5. **Reusable containers**: Enable reusable containers to eliminate `t`.
6. **Traces**: Use traces for lightweight operations within tasks.
7. **Cache**: Enable caching for deterministic, expensive tasks.
8. **Limit fanout**: Keep total actions below 50k (target 10k-20k).
9. **Monitor**: Use the UI to monitor execution and identify issues.
10. **Iterate**: Continuously refine based on performance metrics.
## Real-world example: PyIceberg batch processing
For a comprehensive example of efficient data processing with Flyte, see the [PyIceberg parallel batch aggregation example](https://github.com/unionai/flyte-sdk/blob/main/examples/data_processing/pyiceberg_example.py). This example demonstrates:
- **Zero-copy data passing**: Pass file paths instead of data between tasks
- **Reusable containers with concurrency**: Maximize CPU utilization across workers
- **Parallel file processing**: Use `asyncio.gather()` to process multiple files concurrently
- **Efficient batching**: Distribute parquet files across worker tasks
Key pattern from the example:
```python
# Instead of loading entire table, get file paths
file_paths = [task.file.file_path for task in table.scan().plan_files()]
# Distribute files across partitions (zero-copy!)
partition_files = distribute_files(file_paths, num_partitions)
# Process partitions in parallel
results = await asyncio.gather(*[
aggregate_partition(files, partition_id)
for partition_id, files in enumerate(partition_files)
])
```
This approach achieves true parallel file processing without loading the entire dataset into memory.
## Example: Optimizing a data pipeline
### Before optimization
```python
import asyncio

@env.task
async def process_item(item: dict) -> dict:
    # Very fast operation (~100ms)
    return {"processed": item["id"]}

@env.task
async def process_dataset(items: list[dict]) -> list[dict]:
# Create 1M tasks
results = await asyncio.gather(*[process_item(item) for item in items])
return results
```
**Issues**:
- 1M tasks created (exceeds UI limit)
- Task overhead >> task runtime (100ms task, seconds of overhead)
- High load on Queue Service and object storage
### After optimization
```python
from datetime import timedelta

import flyte

# Use reusable containers
env = flyte.TaskEnvironment(
name="optimized-pipeline",
reuse_policy=flyte.ReusePolicy(
replicas=(5, 20),
concurrency=10,
scaledown_ttl=timedelta(minutes=10),
idle_ttl=timedelta(hours=1)
)
)
@env.task
async def process_batch(items: list[dict]) -> list[dict]:
# Process batch of items
return [{"processed": item["id"]} for item in items]
@env.task
async def process_dataset(items: list[dict]) -> list[dict]:
# Create 1000 tasks (batch size 1000)
batch_size = 1000
batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
results = await flyte.map(process_batch, batches)
return [item for batch in results for item in batch]
```
**Improvements**:
- 1000 tasks instead of 1M (within limits)
- Batch runtime ~100 seconds (100ms Γ 1000 items)
- Reusable containers eliminate startup overhead
- Concurrency enables high throughput (200 concurrent tasks max)
## When to contact the Union team
Reach out to the Union team if you:
- Need more than 50k actions per run
- Want to use high-performance metastores (Redis, PostgreSQL) instead of object stores
- Have specific performance requirements or constraints
- Need help profiling and optimizing your workflows
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps ===
# Configure apps
`[[AppEnvironment]]`s allow you to configure the environment in which your app runs, including the container image, compute resources, secrets, domains, scaling behavior, and more.
Similar to `[[TaskEnvironment]]`, configuration can be set when creating the `[[AppEnvironment]]` object. Unlike tasks, apps are long-running services, so they have additional configuration options specific to web services:
- `port`: What port the app listens on
- `command` and `args`: How to start the app
- `scaling`: Autoscaling configuration for handling variable load
- `domain`: Custom domains and subdomains for your app
- `requires_auth`: Whether the app requires authentication to access
- `depends_on`: Other app or task environments that the app depends on
## Hello World example
Here's a complete example of deploying a simple Streamlit "hello world" app with a custom subdomain:
```
"""A basic "Hello World" app example with custom subdomain."""
import flyte
import flyte.app
# {{docs-fragment image}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1")
# {{/docs-fragment image}}
# {{docs-fragment app-env}}
app_env = flyte.app.AppEnvironment(
name="hello-world-app",
image=image,
args=["streamlit", "hello", "--server.port", "8080"],
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
domain=flyte.app.Domain(subdomain="hello"),
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config()
# Deploy the app
app = flyte.serve(app_env)
print(f"App served at: {app.url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/hello-world-app.py)
This example demonstrates:
- Creating a custom Docker image with Streamlit
- Setting the `args` to run the Streamlit hello app
- Configuring the port
- Setting resource limits
- Disabling authentication (for public access)
- Using a custom subdomain
Once deployed, your app will be accessible at the generated URL or your custom subdomain.
## Differences from TaskEnvironment
While `AppEnvironment` inherits from `Environment` (the same base class as `TaskEnvironment`), it has several app-specific parameters:
| Parameter | AppEnvironment | TaskEnvironment | Description |
|-----------|----------------|-----------------|-------------|
| `type` | ✓ | ✗ | Type of app (e.g., "FastAPI", "Streamlit") |
| `port` | ✓ | ✗ | Port the app listens on |
| `args` | ✓ | ✗ | Arguments to pass to the app |
| `command` | ✓ | ✗ | Command to run the app |
| `requires_auth` | ✓ | ✗ | Whether app requires authentication |
| `scaling` | ✓ | ✗ | Autoscaling configuration |
| `domain` | ✓ | ✗ | Custom domain/subdomain |
| `links` | ✓ | ✗ | Links to include in the App UI page |
| `include` | ✓ | ✗ | Files to include in app |
| `inputs` | ✓ | ✗ | Inputs to pass to app |
| `cluster_pool` | ✓ | ✗ | Cluster pool for deployment |
Parameters like `image`, `resources`, `secrets`, `env_vars`, and `depends_on` are shared between both environment types. See the **Configure tasks** docs for details on these shared parameters.
## Configuration topics
Learn more about configuring apps:
- **Configure apps > App environment settings**: Images, resources, secrets, and app-specific settings like `type`, `port`, `args`, `requires_auth`
- **Configure apps > App environment settings > App startup**: Understanding the difference between `args` and `command`
- **Configure apps > Including additional files**: How to include additional files needed by your app
- **Configure apps > Passing inputs into app environments**: Pass inputs to your app at deployment time
- **Configure apps > Autoscaling apps**: Configure scaling up and down based on traffic with idle TTL
- **Configure apps > Apps depending on other environments**: Use `depends_on` to deploy dependent apps together
## Subpages
- **Configure apps > App environment settings**
- **Configure apps > Including additional files**
- **Configure apps > Passing inputs into app environments**
- **Configure apps > Autoscaling apps**
- **Configure apps > Apps depending on other environments**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/app-environment-settings ===
# App environment settings
`[[AppEnvironment]]`s control how your apps run in Flyte, including images, resources, secrets, startup behavior, and autoscaling.
## Shared environment settings
`[[AppEnvironment]]`s share many configuration options with `[[TaskEnvironment]]`s:
- **Images**: See **Configure tasks > Container images** for details on creating and using container images
- **Resources**: See **Configure tasks > Resources** for CPU, memory, GPU, and storage configuration
- **Secrets**: See **Configure tasks > Secrets** for injecting secrets into your app
- **Environment variables**: Set via the `env_vars` parameter (same as tasks)
- **Cluster pools**: Specify via the `cluster_pool` parameter
## App-specific environment settings
### `type`
The `type` parameter is an optional string that identifies what kind of app this is. It's used for organizational purposes and may be used by the UI or tooling to display or filter apps.
```python
app_env = flyte.app.AppEnvironment(
name="my-fastapi-app",
type="FastAPI",
# ...
)
```
When using specialized app environments like `FastAPIAppEnvironment`, the type is automatically set. For custom apps, you can set it to any string value.
### `port`
The `port` parameter specifies which port your app listens on. It can be an integer or a `Port` object.
```python
# Using an integer (simple case)
app_env = flyte.app.AppEnvironment(name="my-app", port=8080, ...)
# Using a Port object (more control)
app_env = flyte.app.AppEnvironment(
name="my-app",
port=flyte.app.Port(port=8080),
# ...
)
```
The default port is `8080`. Your app should listen on this port (or the port you specify).
> [!NOTE]
> Ports 8012, 8022, 8112, 9090, and 9091 are reserved and cannot be used for apps.
### `args`
The `args` parameter specifies arguments to pass to your app's command. This is typically used when you need to pass additional arguments to the command specified in `command`, or when using the default command behavior.
```python
app_env = flyte.app.AppEnvironment(
name="streamlit-app",
args="streamlit run main.py --server.port 8080",
port=8080,
# ...
)
```
`args` can be either a string (which will be shell-split) or a list of strings:
```python
# String form (will be shell-split)
args="--option1 value1 --option2 value2"
# List form (more explicit)
args=["--option1", "value1", "--option2", "value2"]
```
#### Environment variable substitution
Environment variables are automatically substituted in `args` strings when they start with the `$` character. This works for both:
- Values from `env_vars`
- Secrets that are specified as environment variables (via `as_env_var` in `flyte.Secret`)
The `$VARIABLE_NAME` syntax will be replaced with the actual environment variable value at runtime:
```python
# Using env_vars
app_env = flyte.app.AppEnvironment(
name="my-app",
env_vars={"API_KEY": "secret-key-123"},
args="--api-key $API_KEY", # $API_KEY will be replaced with "secret-key-123"
# ...
)
# Using secrets
app_env = flyte.app.AppEnvironment(
name="my-app",
secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
args=["--api-key", "$AUTH_SECRET"], # $AUTH_SECRET will be replaced with the secret value
# ...
)
```
This is particularly useful for passing API keys or other sensitive values to command-line arguments without hardcoding them in your code. The substitution happens at runtime, ensuring secrets are never exposed in your code or configuration files.
> [!TIP]
> For most `AppEnvironment`s, use `args` instead of `command` to specify the app startup command
> in the container. This is because `args` will use the `fserve` command to run the app, which
> unlocks features like local code bundling and file/directory mounting via input injection.
### `command`
The `command` parameter specifies the full command to run your app. If not specified, Flyte uses a default command that runs your app via `fserve`, the app-serving command provided by the `flyte` package.
```python
# Explicit command
app_env = flyte.app.AppEnvironment(
name="streamlit-hello",
command="streamlit hello --server.port 8080",
port=8080,
# ...
)
# Using default command (recommended for most cases)
# When command is None, Flyte generates a command based on your app configuration
app_env = flyte.app.AppEnvironment(name="my-app", ...) # command=None by default
```
> [!TIP]
> For most apps, especially when using specialized app environments like `FastAPIAppEnvironment`, you don't need to specify `command` as it's automatically configured. Use `command` when you need
> to specify the raw container command, e.g. when running a non-Python app or when you have all
> of the dependencies and data used by the app available in the container.
### `requires_auth`
The `requires_auth` parameter controls whether the app requires authentication to access. By default, apps require authentication (`requires_auth=True`).
```python
# Public app (no authentication required)
app_env = flyte.app.AppEnvironment(
name="public-dashboard",
requires_auth=False,
# ...
)
# Private app (authentication required - default)
app_env = flyte.app.AppEnvironment(
name="internal-api",
requires_auth=True,
# ...
) # Default
```
When `requires_auth=True`, users must authenticate with Flyte to access the app. When `requires_auth=False`, the app is publicly accessible (though it may still require API keys or other app-level authentication).
### `domain`
The `domain` parameter specifies a custom domain or subdomain for your app. Use `flyte.app.Domain` to configure a subdomain or custom domain.
```python
app_env = flyte.app.AppEnvironment(
name="my-app",
domain=flyte.app.Domain(subdomain="myapp"),
# ...
)
```
### `links`
The `links` parameter adds links to the App UI page. Use `flyte.app.Link` objects to specify relative or absolute links with titles.
```python
app_env = flyte.app.AppEnvironment(
name="my-app",
links=[
flyte.app.Link(path="/docs", title="API Documentation", is_relative=True),
flyte.app.Link(path="/health", title="Health Check", is_relative=True),
flyte.app.Link(path="https://www.example.com", title="External link", is_relative=False),
],
# ...
)
```
### `include`
The `include` parameter specifies files and directories to include in the app bundle. Use glob patterns or explicit paths to include code files needed by your app.
```python
app_env = flyte.app.AppEnvironment(
name="my-app",
include=["*.py", "models/", "utils/", "requirements.txt"],
# ...
)
```
> [!NOTE]
> Learn more about including additional files in your app deployment in **Configure apps > Including additional files**.
### `inputs`
The `inputs` parameter passes inputs to your app at deployment time. Inputs can be primitive values, files, directories, or delayed values like `RunOutput` or `AppEndpoint`.
```python
app_env = flyte.app.AppEnvironment(
name="my-app",
inputs=[
flyte.app.Input(name="config", value="foo", env_var="BAR"),
flyte.app.Input(name="model", value=flyte.io.File(path="s3://bucket/model.pkl"), mount="/mnt/model"),
flyte.app.Input(name="data", value=flyte.io.File(path="s3://bucket/data.pkl"), mount="/mnt/data"),
],
# ...
)
```
> [!NOTE]
> Learn more about passing inputs to your app at deployment time in **Configure apps > Passing inputs into app environments**.
### `scaling`
The `scaling` parameter configures autoscaling behavior for your app. Use `flyte.app.Scaling` to set replica ranges and idle TTL.
```python
app_env = flyte.app.AppEnvironment(
name="my-app",
scaling=flyte.app.Scaling(
replicas=(1, 5),
scaledown_after=300, # Scale down after 5 minutes of idle time
),
# ...
)
```
> [!NOTE]
> Learn more about autoscaling apps in **Configure apps > Autoscaling apps**.
### `depends_on`
The `depends_on` parameter specifies environment dependencies. When you deploy an app, all dependencies are deployed first.
```python
backend_env = flyte.app.AppEnvironment(name="backend-api", ...)
frontend_env = flyte.app.AppEnvironment(
name="frontend-app",
depends_on=[backend_env], # backend-api will be deployed first
# ...
)
```
> [!NOTE]
> Learn more about app environment dependencies in **Configure apps > Apps depending on other environments**.
## App startup
Understanding the difference between `args` and `command` is crucial for properly configuring how your app starts.
### Command vs args
In container terminology:
- **`command`**: The executable or entrypoint that runs
- **`args`**: Arguments passed to that command
In Flyte apps:
- **`command`**: The full command to run your app (for example, `"streamlit hello --server.port 8080"`)
- **`args`**: Arguments to pass to your app's command (used with the default Flyte command or your custom command)
### Default startup behavior
When you don't specify a `command`, Flyte generates a default command that uses `fserve` to run your app. This default command handles:
- Setting up the code bundle
- Configuring the version
- Setting up project/domain context
- Injecting inputs if provided
The default command looks like:
```bash
fserve --version ... --project ... --domain ... --
```
So if you specify `args`, they'll be appended after the `--` separator.
### Startup examples
#### Using args with default command
When you use `args` without specifying `command`, the args are passed to the default Flyte command:
```
import flyte
import flyte.app

# {{docs-fragment args-with-default-command}}
# Using args with default command
app_env = flyte.app.AppEnvironment(
    name="streamlit-app",
    args="streamlit run main.py --server.port 8080",
    port=8080,
    include=["main.py"],
    # command is None, so the default Flyte command is used
)
# {{/docs-fragment args-with-default-command}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py)
This effectively runs:
```bash
fserve --version ... --project ... --domain ... -- streamlit run main.py --server.port 8080
```
#### Using explicit command
When you specify a `command`, it completely replaces the default command:
```
import flyte
import flyte.app

# {{docs-fragment explicit-command}}
# Using explicit command
app_env2 = flyte.app.AppEnvironment(
    name="streamlit-hello",
    command="streamlit hello --server.port 8080",
    port=8080,
    # No args needed since command includes everything
)
# {{/docs-fragment explicit-command}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py)
This runs exactly:
```bash
streamlit hello --server.port 8080
```
#### Using command with args
You can combine both, though this is less common:
```
import flyte
import flyte.app

# {{docs-fragment command-with-args}}
# Using command with args
app_env3 = flyte.app.AppEnvironment(
    name="custom-app",
    command="python -m myapp",
    args="--option1 value1 --option2 value2",
    # This runs: python -m myapp --option1 value1 --option2 value2
)
# {{/docs-fragment command-with-args}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py)
#### FastAPIAppEnvironment example
When using `FastAPIAppEnvironment`, the command is automatically configured to run uvicorn:
```
from fastapi import FastAPI
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment fastapi-auto-command}}
# FastAPIAppEnvironment automatically sets command
app = FastAPI()
env = FastAPIAppEnvironment(
    name="my-api",
    app=app,
    # command is automatically set to: uvicorn <module>:<app> --port 8080
    # You typically don't need to specify command or args
)
# {{/docs-fragment fastapi-auto-command}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-startup-examples.py)
The `FastAPIAppEnvironment` automatically:
1. Detects the module and variable name of your FastAPI app
2. Sets the command to run `uvicorn <module>:<app> --port <port>`
3. Handles all the startup configuration for you
### Startup best practices
1. **Use specialized app environments** when available (for example, `FastAPIAppEnvironment`) β they handle command setup automatically.
2. **Use `args`** when you need code bundling and input injection.
3. **Use `command`** for simple, standalone apps that don't need code bundling.
4. **Always set `port`** to match what your app actually listens on.
5. **Use `include`** with `args` to bundle your app code files.
## Complete example
Here's a complete example showing various environment and startup settings:
```
"""Complete example showing various environment settings."""
import flyte
import flyte.app
# {{docs-fragment complete-example}}
# Create a custom image
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi==0.104.1",
"uvicorn==0.24.0",
"python-multipart==0.0.6",
)
# Configure app with various settings
app_env = flyte.app.AppEnvironment(
name="my-api",
type="FastAPI",
image=image,
port=8080,
resources=flyte.Resources(
cpu="2",
memory="4Gi",
),
secrets=flyte.Secret(key="my-api-key", as_env_var="API_KEY"),
env_vars={
"LOG_LEVEL": "INFO",
"ENVIRONMENT": "production",
},
requires_auth=False, # Public API
cluster_pool="production-pool",
description="My production FastAPI service",
)
# {{/docs-fragment complete-example}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/environment-settings-example.py)
This example demonstrates:
- Setting a custom `type` identifier
- Configuring the port
- Specifying compute resources
- Injecting secrets as environment variables
- Setting environment variables
- Making the app publicly accessible
- Targeting a specific cluster pool
- Adding a description
For more details on shared settings like images, resources, and secrets, refer to the **Configure tasks** documentation.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/including-additional-files ===
# Including additional files
When your app needs additional files beyond the main script (like utility modules, configuration files, or data files), you can use the `include` parameter to specify which files to bundle with your app.
## How include works
The `include` parameter takes a list of file paths (relative to the directory containing your app definition). These files are bundled together and made available in the app container at runtime.
```python
include=["main.py", "utils.py", "config.yaml"]
```
## When to use include
Use `include` when:
- Your app spans multiple Python files (modules)
- You have configuration files that your app needs
- You have data files or templates your app uses
- You want to ensure specific files are available in the container
> [!NOTE]
> If you're using specialized app environments like `FastAPIAppEnvironment`, Flyte automatically detects and includes the necessary files, so you may not need to specify `include` explicitly.
## Examples
### Multi-file Streamlit app
```
"""A custom Streamlit app with multiple files."""
import pathlib
import flyte
import flyte.app
# {{docs-fragment image}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1",
"pandas==2.2.3",
"numpy==2.2.3",
)
# {{/docs-fragment image}}
# {{docs-fragment app-env}}
app_env = flyte.app.AppEnvironment(
name="streamlit-custom-app",
image=image,
args="streamlit run main.py --server.port 8080",
port=8080,
include=["main.py", "utils.py"], # Include your app files
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.deploy(app_env)
print(f"App URL: {app[0].url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/custom_streamlit.py)
In this example:
- `main.py` is your main Streamlit app file
- `utils.py` contains helper functions used by `main.py`
- Both files are included in the app bundle
### Multi-file FastAPI app
```
"""Multi-file FastAPI app example."""
from fastapi import FastAPI
from module import function # Import from another file
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment app-definition}}
app = FastAPI(title="Multi-file FastAPI Demo")
app_env = FastAPIAppEnvironment(
name="fastapi-multi-file",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
# FastAPIAppEnvironment automatically includes necessary files
# But you can also specify explicitly:
# include=["app.py", "module.py"],
)
# {{/docs-fragment app-definition}}
# {{docs-fragment endpoint}}
@app.get("/")
async def root():
return function() # Uses function from module.py
# {{/docs-fragment endpoint}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(app_env)
print(f"Deployed: {app_deployment[0].url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/app.py)
### App with configuration files
```python
include=["app.py", "config.yaml", "templates/"]
```
## File discovery
When using specialized app environments like `FastAPIAppEnvironment`, Flyte uses code introspection to automatically discover and include the necessary files. This means you often don't need to manually specify `include`.
However, if you have files that aren't automatically detected (like configuration files, data files, or templates), you should explicitly list them in `include`.
## Path resolution
Files in `include` are resolved relative to the directory containing your app definition file. For example:
```
project/
βββ apps/
β βββ app.py # Your app definition
β βββ utils.py # Included file
β βββ config.yaml # Included file
```
In `app.py`:
```python
include=["utils.py", "config.yaml"] # Relative to apps/ directory
```
## Best practices
1. **Only include what you need**: Don't include unnecessary files as it increases bundle size
2. **Use relative paths**: Always use paths relative to your app definition file
3. **Include directories**: You can include entire directories, but be mindful of size
4. **Test locally**: Verify your includes work by testing locally before deploying
5. **Check automatic discovery**: Specialized app environments may already include files automatically
## Limitations
- Large files or directories can slow down deployment
- Binary files are supported but consider using data storage (S3, etc.) for very large files
- The bundle size is limited by your Flyte cluster configuration
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/passing-inputs ===
# Passing inputs into app environments
`[[AppEnvironment]]`s support various input types that can be passed at deployment time. This includes primitive values, files, directories, and delayed values like `RunOutput` and `AppEndpoint`.
## Input types overview
There are several input types:
- **Primitive values**: Strings, numbers, booleans
- **Files**: `flyte.io.File` objects
- **Directories**: `flyte.io.Dir` objects
- **Delayed values**: `RunOutput` (from task runs) or `AppEndpoint` (from other apps)
## Basic input types
```
import flyte
import flyte.app
import flyte.io

# {{docs-fragment basic-input-types}}
# String inputs
app_env = flyte.app.AppEnvironment(
    name="configurable-app",
    inputs=[
        flyte.app.Input(name="environment", value="production"),
        flyte.app.Input(name="log_level", value="INFO"),
    ],
    # ...
)

# File inputs
app_env2 = flyte.app.AppEnvironment(
    name="app-with-model",
    inputs=[
        flyte.app.Input(
            name="model_file",
            value=flyte.io.File("s3://bucket/models/model.pkl"),
            mount="/app/models",
        ),
    ],
    # ...
)

# Directory inputs
app_env3 = flyte.app.AppEnvironment(
    name="app-with-data",
    inputs=[
        flyte.app.Input(
            name="data_dir",
            value=flyte.io.Dir("s3://bucket/data/"),
            mount="/app/data",
        ),
    ],
    # ...
)
# {{/docs-fragment basic-input-types}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-inputs-examples.py)
## Delayed values
Delayed values are inputs whose actual values are materialized at deployment time.
### RunOutput
Use `RunOutput` to pass outputs from task runs as app inputs:
```
import flyte
import flyte.app
import flyte.io

# {{docs-fragment runoutput-example}}
# Delayed inputs with RunOutput
env = flyte.TaskEnvironment(name="training-env")

@env.task
async def train_model() -> flyte.io.File:
    # ... training logic ...
    return await flyte.io.File.from_local("/tmp/trained-model.pkl")

# Use the task output as an app input
app_env4 = flyte.app.AppEnvironment(
    name="serving-app",
    inputs=[
        flyte.app.Input(
            name="model",
            value=flyte.app.RunOutput(type="file", run_name="training_run", task_name="train_model"),
            mount="/app/model",
        ),
    ],
    # ...
)
# {{/docs-fragment runoutput-example}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-inputs-examples.py)
The `type` argument is required and must be one of `string`, `file`, or `directory`.
When the app is deployed, Flyte makes the remote calls needed to resolve the actual value of the input.
### AppEndpoint
Use `AppEndpoint` to pass endpoints from other apps:
```
import flyte
import flyte.app

# {{docs-fragment appendpoint-example}}
# Delayed inputs with AppEndpoint
app1_env = flyte.app.AppEnvironment(name="backend-api")
app2_env = flyte.app.AppEnvironment(
    name="frontend-app",
    inputs=[
        flyte.app.Input(
            name="backend_url",
            value=flyte.app.AppEndpoint(app_name="backend-api"),
            env_var="BACKEND_URL",  # app1_env's endpoint will be available as an environment variable
        ),
    ],
    # ...
)
# {{/docs-fragment appendpoint-example}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-inputs-examples.py)
The endpoint URL will be injected as the input value when the app starts.
This is particularly useful when you want to chain apps together (for example, a frontend app calling a backend app), without hardcoding URLs.
## Overriding inputs at serve time
You can override input values when serving apps (this is not supported for deployment):
```python
# Override inputs when serving
app = flyte.with_servecontext(
input_values={"my-app": {"model_path": "s3://bucket/new-model.pkl"}}
).serve(app_env)
```
> [!NOTE]
> Input overrides are only available when using `flyte.serve()` or `flyte.with_servecontext().serve()`.
> The `flyte.deploy()` function does not support input overrides - inputs must be specified in the `AppEnvironment` definition.
This is useful for:
- Testing different configurations during development
- Using different models or data sources for testing
- A/B testing different app configurations
## Example: FastAPI app with configurable model
Here's a complete example showing how to use inputs in a FastAPI app:
```
"""Example: FastAPI app with configurable model input."""
from contextlib import asynccontextmanager
from flyte.app.extras import FastAPIAppEnvironment
from fastapi import FastAPI
import os
import flyte
import joblib
# {{docs-fragment model-serving-api}}
state = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# Access input via environment variable
model = joblib.load(os.getenv("MODEL_PATH", "/app/models/default.pkl"))
state["model"] = model
yield
app = FastAPI(lifespan=lifespan)
app_env = FastAPIAppEnvironment(
name="model-serving-api",
app=app,
inputs=[
flyte.app.Input(
name="model_file",
value=flyte.io.File("s3://bucket/models/default.pkl"),
mount="/app/models",
env_var="MODEL_PATH",
),
],
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi", "uvicorn", "scikit-learn"
),
resources=flyte.Resources(cpu=2, memory="2Gi"),
requires_auth=False,
)
@app.get("/predict")
async def predict(data: dict):
model = state["model"]
return {"prediction": model.predict(data)}
# {{/docs-fragment model-serving-api}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/app-inputs-fastapi-example.py)
## Example: Using RunOutput for model serving
```
import joblib
from fastapi import FastAPI
from sklearn.ensemble import RandomForestClassifier

import flyte
import flyte.app
import flyte.io
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment runoutput-serving-example}}
# Training task
training_env = flyte.TaskEnvironment(name="training-env")

@training_env.task
async def train_model_task() -> flyte.io.File:
    """Train a model and return it."""
    model = RandomForestClassifier()
    # ... training logic ...
    path = "./trained-model.pkl"
    joblib.dump(model, path)
    return await flyte.io.File.from_local(path)

# Serving app that uses the trained model
app = FastAPI()
serving_env = FastAPIAppEnvironment(
    name="model-serving-app",
    app=app,
    inputs=[
        flyte.app.Input(
            name="model",
            value=flyte.app.RunOutput(
                type="file",
                task_name="training-env.train_model_task"
            ),
            mount="/app/model",
            env_var="MODEL_PATH",
        ),
    ],
)
# {{/docs-fragment runoutput-serving-example}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/passing-inputs-examples.py)
## Accessing inputs in your app
How you access inputs depends on how they're configured:
1. **Environment variables**: If `env_var` is specified, the input is available as an environment variable
2. **Mounted paths**: File and directory inputs are mounted at the specified path
3. **Flyte SDK**: Use the Flyte SDK to access input values programmatically
```python
import os
import pickle

import flyte
import flyte.app
import flyte.io

# Input with env_var specified
env = flyte.app.AppEnvironment(
    name="my-app",
    inputs=[
        flyte.app.Input(
            name="model_file",
            value=flyte.io.File("s3://bucket/model.pkl"),
            mount="/app/models/model.pkl",
            env_var="MODEL_PATH",
        ),
    ],
    # ...
)

# Access in the app via the environment variable
model_path = os.getenv("MODEL_PATH")

# Access in the app via the mounted path
with open("/app/models/model.pkl", "rb") as f:
    model = pickle.load(f)

# Access in the app via the Flyte SDK (for string inputs)
input_value = flyte.app.get_input("config")  # Returns string value
```
## Best practices
1. **Use delayed inputs**: Leverage `RunOutput` and `AppEndpoint` to create app dependencies between tasks and apps, or app-to-app chains.
2. **Override for testing**: Use the `input_values` parameter when serving to test different configurations without changing code.
3. **Mount paths clearly**: Use descriptive mount paths for file/directory inputs so your app code is easy to understand.
4. **Use environment variables**: For simple string values, use `env_var` to inject inputs as environment variables instead of hard-coding them.
5. **Production deployments**: For production, define inputs in the `AppEnvironment` rather than overriding them at deploy time.
## Limitations
- Large files/directories can slow down app startup.
- Input overrides are only available when using `flyte.with_servecontext(...).serve(...)`.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/auto-scaling-apps ===
## Autoscaling apps
Flyte apps support autoscaling, allowing them to scale up and down based on traffic. This helps optimize costs by scaling down when there's no traffic and scaling up when needed.
### Scaling configuration
The `scaling` parameter uses a `Scaling` object to configure autoscaling behavior:
```python
scaling=flyte.app.Scaling(
replicas=(min_replicas, max_replicas),
scaledown_after=idle_ttl_seconds,
)
```
#### Parameters
- **`replicas`**: A tuple `(min_replicas, max_replicas)` specifying the minimum and maximum number of replicas.
- **`scaledown_after`**: Time in seconds to wait before scaling down when idle (idle TTL).
### Basic scaling example
Here's a simple example with scaling from 0 to 1 replica:
```
import flyte
import flyte.app

# {{docs-fragment basic-scaling}}
# Basic example: scale from 0 to 1 replica
app_env = flyte.app.AppEnvironment(
    name="autoscaling-app",
    scaling=flyte.app.Scaling(
        replicas=(0, 1),  # Scale from 0 to 1 replica
        scaledown_after=300,  # Scale down after 5 minutes of inactivity
    ),
    # ...
)
# {{/docs-fragment basic-scaling}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py)
This configuration:
- Starts with 0 replicas (no running instances)
- Scales up to 1 replica when there's traffic
- Scales back down to 0 after 5 minutes (300 seconds) of no traffic
### Scaling patterns
#### Always-on app
For apps that need to always be running:
```
import flyte
import flyte.app

# {{docs-fragment always-on}}
# Always-on app
app_env2 = flyte.app.AppEnvironment(
    name="always-on-api",
    scaling=flyte.app.Scaling(
        replicas=(1, 1),  # Always keep 1 replica running
        # scaledown_after is ignored when min_replicas > 0
    ),
    # ...
)
# {{/docs-fragment always-on}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py)
#### Scale-to-zero app
For apps that can scale to zero when idle:
```
import flyte
import flyte.app

# {{docs-fragment scale-to-zero}}
# Scale-to-zero app
app_env3 = flyte.app.AppEnvironment(
    name="scale-to-zero-app",
    scaling=flyte.app.Scaling(
        replicas=(0, 1),  # Can scale down to 0
        scaledown_after=600,  # Scale down after 10 minutes of inactivity
    ),
    # ...
)
# {{/docs-fragment scale-to-zero}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py)
#### High-availability app
For apps that need multiple replicas for availability:
```
# High-availability app
app_env4 = flyte.app.AppEnvironment(
    name="ha-api",
    scaling=flyte.app.Scaling(
        replicas=(2, 5),  # Keep at least 2, scale up to 5
        scaledown_after=300,  # Scale down after 5 minutes
    ),
    # ...
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py)
#### Burstable app
For apps with variable load:
```
# Burstable app
app_env5 = flyte.app.AppEnvironment(
    name="bursty-app",
    scaling=flyte.app.Scaling(
        replicas=(1, 10),  # Start with 1, scale up to 10 under load
        scaledown_after=180,  # Scale down quickly after 3 minutes
    ),
    # ...
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/configure-apps/autoscaling-examples.py)
### Idle TTL (Time To Live)
The `scaledown_after` parameter (idle TTL) determines how long an app instance can be idle before it's scaled down.
#### Considerations
- **Too short**: May cause frequent scale up/down cycles, leading to cold starts.
- **Too long**: Keeps resources running unnecessarily, increasing costs.
- **Optimal**: Balance between cost and user experience.
#### Common idle TTL values
- **Development/Testing**: 60-180 seconds (1-3 minutes) - quick scale down for cost savings.
- **Production APIs**: 300-600 seconds (5-10 minutes) - balance cost and responsiveness.
- **Batch processing**: 900-1800 seconds (15-30 minutes) - longer to handle bursts.
- **Always-on**: Set `min_replicas > 0` - never scale down.
### Autoscaling best practices
1. **Start conservative**: Begin with longer idle TTL values and adjust based on usage.
2. **Monitor cold starts**: Track how long it takes for your app to become ready after scaling up.
3. **Consider costs**: Balance idle TTL between cost savings and user experience.
4. **Use appropriate min replicas**: Set `min_replicas > 0` for critical apps that need to be always available.
5. **Test scaling behavior**: Verify your app handles scale up/down correctly (for example, state management and connections).
### Autoscaling limitations
- Scaling is based on traffic/request patterns, not CPU/memory utilization.
- Cold starts may occur when scaling from zero.
- Stateful apps need careful design to handle scaling (use external state stores).
- Maximum replicas are limited by your cluster capacity.
### Autoscaling troubleshooting
**App scales down too quickly:**
- Increase `scaledown_after` value.
- Set `min_replicas > 0` if the app needs to stay warm.
**App doesn't scale up fast enough:**
- Ensure your cluster has capacity.
- Check if there are resource constraints.
**Cold starts are too slow:**
- Pre-warm with `min_replicas = 1` (see the sketch below).
- Optimize app startup time.
- Consider using faster storage for model loading.
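For example, a minimal pre-warming sketch (using the same `Scaling` API as the examples above; the name and exact replica counts are illustrative):
```python
import flyte
import flyte.app

# Keep one replica always warm to avoid cold starts;
# extra replicas absorb bursts and scale down when idle.
prewarmed_env = flyte.app.AppEnvironment(
    name="prewarmed-app",  # hypothetical name
    scaling=flyte.app.Scaling(
        replicas=(1, 3),      # min 1 replica stays warm; up to 3 under load
        scaledown_after=300,  # extra replicas scale down after 5 minutes idle
    ),
    # ...
)
```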
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/configure-apps/apps-depending-on-environments ===
# Apps depending on other environments
The `depends_on` parameter allows you to specify that one app depends on another app (or task environment). When you deploy an app with `depends_on`, Flyte ensures that all dependencies are deployed first.
## Basic usage
Use `depends_on` to specify a list of environments that this app depends on:
```python
app1_env = flyte.app.AppEnvironment(name="backend-api", ...)
app2_env = flyte.app.AppEnvironment(
name="frontend-app",
depends_on=[app1_env], # Ensure backend-api is deployed first
# ...
)
```
When you deploy `app2_env`, Flyte will:
1. First deploy `app1_env` (if not already deployed)
2. Then deploy `app2_env`
3. Make sure `app1_env` is available before `app2_env` starts
## Example: App calling another app
Here's a complete example where one FastAPI app calls another:
```
"""Example of one app calling another app."""
import httpx
from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi", "uvicorn", "httpx"
)
# {{docs-fragment backend-app}}
app1 = FastAPI(
title="App 1",
description="A FastAPI app that runs some computations",
)
env1 = FastAPIAppEnvironment(
name="app1-is-called-by-app2",
app=app1,
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment backend-app}}
# {{docs-fragment frontend-app}}
app2 = FastAPI(
title="App 2",
description="A FastAPI app that proxies requests to another FastAPI app",
)
env2 = FastAPIAppEnvironment(
name="app2-calls-app1",
app=app2,
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
depends_on=[env1],  # Ensure app1 is deployed first
)
# {{/docs-fragment frontend-app}}
# {{docs-fragment backend-endpoint}}
@app1.get("/greeting/{name}")
async def greeting(name: str) -> str:
return f"Hello, {name}!"
# {{/docs-fragment backend-endpoint}}
# {{docs-fragment frontend-endpoints}}
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
return env1.endpoint # Access the backend endpoint
@app2.get("/greeting/{name}")
async def greeting_proxy(name: str):
"""Proxy that calls the backend app."""
async with httpx.AsyncClient() as client:
response = await client.get(f"{env1.endpoint}/greeting/{name}")
response.raise_for_status()
return response.json()
# {{/docs-fragment frontend-endpoints}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
deployments = flyte.deploy(env2)
print(f"Deployed FastAPI app: {deployments[0].env_repr()}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/app_calling_app.py)
When you deploy `env2`, Flyte will:
1. Deploy `env1` (`app1-is-called-by-app2`) first
2. Wait for `env1` to be ready
3. Deploy `env2` (`app2-calls-app1`)
4. `env2` can then access `env1.endpoint` to make requests
## Dependency chain
You can create chains of dependencies:
```python
app1_env = flyte.app.AppEnvironment(name="service-1", ...)
app2_env = flyte.app.AppEnvironment(name="service-2", depends_on=[app1_env], ...)
app3_env = flyte.app.AppEnvironment(name="service-3", depends_on=[app2_env], ...)
# Deploying app3_env will deploy in order: app1_env -> app2_env -> app3_env
```
## Multiple dependencies
An app can depend on multiple environments:
```python
backend_env = flyte.app.AppEnvironment(name="backend", ...)
database_env = flyte.app.AppEnvironment(name="database", ...)
api_env = flyte.app.AppEnvironment(
name="api",
depends_on=[backend_env, database_env], # Depends on both
# ...
)
```
When deploying `api_env`, both `backend_env` and `database_env` will be deployed first (they may be deployed in parallel if they don't depend on each other).
## Using AppEndpoint for dependency URLs
When one app depends on another, you can use `AppEndpoint` to get the URL:
```python
backend_env = flyte.app.AppEnvironment(name="backend-api", ...)
frontend_env = flyte.app.AppEnvironment(
name="frontend-app",
depends_on=[backend_env],
inputs=[
flyte.app.Input(
name="backend_url",
value=flyte.app.AppEndpoint(app_name="backend-api"),
),
],
# ...
)
```
The `backend_url` input will be automatically set to the backend app's endpoint URL.
You can get this value in your app code using `flyte.app.get_input("backend_url")`.
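For example, a minimal sketch (assuming the `flyte.app.get_input` helper named above; the `/greeting` route is illustrative) of resolving and using that input inside the frontend app:
```python
import httpx
import flyte.app

async def call_backend(name: str):
    # Resolve the backend URL that Flyte injected as an app input.
    backend_url = flyte.app.get_input("backend_url")
    # Call the dependency's endpoint over HTTP.
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{backend_url}/greeting/{name}")
        response.raise_for_status()
        return response.json()
```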
## Deployment behavior
When deploying with `flyte.deploy()`:
```python
# Deploy the app (dependencies are automatically deployed)
deployments = flyte.deploy(env2)
# All dependencies are included in the deployment plan
for deployment in deployments:
print(f"Deployed: {deployment.env.name}")
```
Flyte will:
1. Build a deployment plan that includes all dependencies
2. Deploy dependencies in the correct order
3. Ensure dependencies are ready before deploying dependent apps
## Task environment dependencies
You can also depend on task environments:
```python
task_env = flyte.TaskEnvironment(name="training-env", ...)
serving_env = flyte.app.AppEnvironment(
name="serving-app",
depends_on=[task_env], # Can depend on task environments too
# ...
)
```
This ensures the task environment is available when the app is deployed (useful if the app needs to call tasks in that environment).
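As a hedged sketch (reusing the `flyte.remote` calls shown in the webhook examples of this guide; the project, domain, task name, and version are placeholders), the serving app could then trigger a task from that environment:
```python
import flyte
import flyte.remote as remote

async def trigger_training_run():
    # Fetch a task registered in the task environment this app depends on.
    task = await remote.TaskDetails.fetch(
        project="my-project",  # placeholder
        domain="development",  # placeholder
        name="training-task",  # placeholder
        version="v1",          # placeholder
    )
    # Launch the task and hand back the run URL.
    run = await flyte.run.aio(task)
    return run.url
```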
## Best practices
1. **Explicit dependencies**: Always use `depends_on` to make app dependencies explicit
2. **Circular dependencies**: Avoid circular dependencies (app A depends on B, B depends on A)
3. **Dependency order**: Design your dependency graph to be a DAG (Directed Acyclic Graph)
4. **Endpoint access**: Use `AppEndpoint` to pass dependency URLs as inputs
5. **Document dependencies**: Make sure your app documentation explains its dependencies
## Example: A/B testing with dependencies
Here's an example of an A/B testing setup where a root app depends on two variant apps:
```python
app_a = FastAPI(title="Variant A")
app_b = FastAPI(title="Variant B")
root_app = FastAPI(title="Root App")
env_a = FastAPIAppEnvironment(name="app-a-variant", app=app_a, ...)
env_b = FastAPIAppEnvironment(name="app-b-variant", app=app_b, ...)
env_root = FastAPIAppEnvironment(
name="root-ab-testing-app",
app=root_app,
depends_on=[env_a, env_b], # Depends on both variants
# ...
)
```
The root app can route traffic to either variant A or B based on A/B testing logic, and both variants will be deployed before the root app starts.
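A rough sketch of that routing logic inside the root app (the 50/50 split and the `/predict` route are illustrative; endpoint access via `env_a.endpoint` and `env_b.endpoint` follows the patterns above):
```python
import random
import httpx

@root_app.get("/predict")
async def predict_proxy():
    # Pick a variant per request: 50/50 split between A and B.
    target = env_a.endpoint if random.random() < 0.5 else env_b.endpoint
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{target}/predict")  # hypothetical route
        response.raise_for_status()
        return response.json()
```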
## Limitations
- Circular dependencies are not supported
- Dependencies must be in the same project/domain
- Dependency deployment order is deterministic but dependencies at the same level may deploy in parallel
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps ===
# Build apps
This section covers how to build different types of apps with Flyte, including Streamlit dashboards, FastAPI REST APIs, vLLM and SGLang model servers, webhooks, and WebSocket applications.
> [!TIP]
> Go to **Getting started > Serving apps** to see a quick example of how to serve an app.
## App types
Flyte supports various types of apps:
- **UI dashboard apps**: Interactive web dashboards and data visualization tools like Streamlit and Gradio
- **Web API apps**: REST APIs, webhooks, and backend services like FastAPI and Flask
- **Model serving apps**: High-performance LLM serving with vLLM and SGLang
## Next steps
- **Build apps > Single-script apps**: The simplest way to build and deploy apps in a single Python script
- **Build apps > Multi-script apps**: Build FastAPI and Streamlit apps with multiple files
- **Build apps > App usage patterns**: Call apps from tasks, tasks from apps, and apps from apps
- **Build apps > Secret-based authentication**: Authenticate FastAPI apps using Flyte secrets
- **Build apps > Streamlit app**: Build interactive Streamlit dashboards
- **Build apps > FastAPI app**: Create REST APIs and backend services
- **Build apps > vLLM app**: Serve large language models with vLLM
- **Build apps > SGLang app**: Serve LLMs with SGLang for structured generation
## Subpages
- **Build apps > Single-script apps**
- **Build apps > Multi-script apps**
- **Build apps > App usage patterns**
- **Build apps > Secret-based authentication**
- **Build apps > Streamlit app**
- **Build apps > FastAPI app**
- **Build apps > vLLM app**
- **Build apps > SGLang app**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/single-script-apps ===
# Single-script apps
The simplest way to build and deploy an app with Flyte is to write everything in a single Python script. This approach is perfect for:
- **Quick prototypes**: Rapidly test ideas and concepts
- **Simple services**: Basic HTTP servers, APIs, or dashboards
- **Learning**: Understanding how Flyte apps work without complexity
- **Minimal examples**: Demonstrating core functionality
All the code for your app (the application logic, the app environment configuration, and the deployment code) lives in one file. This makes it easy to understand, share, and deploy.
## Plain Python HTTP server
The simplest possible app is a plain Python HTTP server using Python's built-in `http.server` module. This requires no external dependencies beyond the Flyte SDK.
```
"""A plain Python HTTP server example - the simplest possible app."""
import flyte
import flyte.app
from pathlib import Path
# {{docs-fragment server-code}}
# Create a simple HTTP server handler
from http.server import HTTPServer, BaseHTTPRequestHandler
class SimpleHandler(BaseHTTPRequestHandler):
"""A simple HTTP server handler."""
def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(b"
Hello from Plain Python Server!
")
elif self.path == "/health":
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(b'{"status": "healthy"}')
else:
self.send_response(404)
self.end_headers()
# {{/docs-fragment server-code}}
# {{docs-fragment app-env}}
file_name = Path(__file__).name
app_env = flyte.app.AppEnvironment(
name="plain-python-server",
image=flyte.Image.from_debian_base(python_version=(3, 12)),
args=["python", file_name, "--server"],
port=8080,
resources=flyte.Resources(cpu="1", memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
import sys
if "--server" in sys.argv:
server = HTTPServer(("0.0.0.0", 8080), SimpleHandler)
print("Server running on port 8080")
server.serve_forever()
else:
flyte.init_from_config(root_dir=Path(__file__).parent)
app = flyte.serve(app_env)
print(f"App URL: {app.url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/plain_python_server.py)
**Key points**
- **No external dependencies**: Uses only Python's standard library
- **Simple handler**: Define request handlers as Python classes
- **Basic command**: Run the server with a simple Python command
- **Minimal resources**: Requires only basic CPU and memory
## Streamlit app
Streamlit makes it easy to build interactive web dashboards. Here's a complete single-script Streamlit app:
```
"""A single-script Streamlit app example."""
import pathlib
import streamlit as st
import flyte
import flyte.app
# {{docs-fragment streamlit-app}}
def main():
st.set_page_config(page_title="Simple Streamlit App", page_icon="👋")
st.title("Hello from Streamlit!")
st.write("This is a simple single-script Streamlit app.")
name = st.text_input("What's your name?", "World")
st.write(f"Hello, {name}!")
if st.button("Click me!"):
st.balloons()
st.success("Button clicked!")
# {{/docs-fragment streamlit-app}}
# {{docs-fragment app-env}}
file_name = pathlib.Path(__file__).name
app_env = flyte.app.AppEnvironment(
name="streamlit-single-script",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1"
),
args=["streamlit", "run", file_name, "--server.port", "8080", "--", "--server"],
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
import sys
if "--server" in sys.argv:
main()
else:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.serve(app_env)
print(f"App URL: {app.url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit_single_script.py)
**Key points**
- **Interactive UI**: Streamlit provides widgets and visualizations out of the box
- **Single file**: All UI logic and deployment code in one script
- **Simple deployment**: Just specify the Streamlit command and port
- **Rich ecosystem**: Access to Streamlit's extensive component library
## FastAPI app
FastAPI is a modern, fast web framework for building APIs. Here's a minimal single-script FastAPI app:
```
"""A single-script FastAPI app example - the simplest FastAPI app."""
from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment fastapi-app}}
app = FastAPI(
title="Simple FastAPI App",
description="A minimal single-script FastAPI application",
version="1.0.0",
)
@app.get("/")
async def root():
return {"message": "Hello, World!"}
@app.get("/health")
async def health():
return {"status": "healthy"}
# {{/docs-fragment fastapi-app}}
# {{docs-fragment app-env}}
app_env = FastAPIAppEnvironment(
name="fastapi-single-script",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.serve(app_env)
print(f"Deployed: {app_deployment[0].url}")
print(f"API docs: {app_deployment[0].url}/docs")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi_single_script.py)
**Key points**
- **FastAPIAppEnvironment**: Automatically configures uvicorn and FastAPI
- **Type hints**: FastAPI uses Python type hints for automatic validation
- **Auto docs**: Interactive API documentation at `/docs` endpoint
- **Async support**: Built-in support for async/await patterns
## Running single-script apps
To run any of these examples:
1. **Save the script** to a file (e.g., `my_app.py`)
2. **Ensure you have a config file** (`./.flyte/config.yaml` or `./config.yaml`)
3. **Run the script**:
```bash
python my_app.py
```
Or using `uv`:
```bash
uv run my_app.py
```
The script will:
- Initialize Flyte from your config
- Deploy the app to your Union/Flyte instance
- Print the app URL
## When to use single-script apps
**Use single-script apps when:**
- Building prototypes or proof-of-concepts
- Creating simple services with minimal logic
- Learning how Flyte apps work
- Sharing complete, runnable examples
- Building demos or tutorials
**Consider multi-script apps when:**
- Your app grows beyond a few hundred lines
- You need to organize code into modules
- You want to reuse components across apps
- You're building production applications
See **Build apps > Multi-script apps** for examples of organizing apps across multiple files.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/multi-script-apps ===
# Multi-script apps
Real-world applications often span multiple files. This page shows how to build FastAPI and Streamlit apps with multiple Python files.
## FastAPI multi-script app
### Project structure
```
project/
├── app.py      # Main FastAPI app file
└── module.py   # Helper module
### Example: Multi-file FastAPI app
```
"""Multi-file FastAPI app example."""
from fastapi import FastAPI
from module import function # Import from another file
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment app-definition}}
app = FastAPI(title="Multi-file FastAPI Demo")
app_env = FastAPIAppEnvironment(
name="fastapi-multi-file",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
# FastAPIAppEnvironment automatically includes necessary files
# But you can also specify explicitly:
# include=["app.py", "module.py"],
)
# {{/docs-fragment app-definition}}
# {{docs-fragment endpoint}}
@app.get("/")
async def root():
return function() # Uses function from module.py
# {{/docs-fragment endpoint}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(app_env)
print(f"Deployed: {app_deployment[0].url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/app.py)
```
# {{docs-fragment helper-function}}
def function():
"""Helper function used by the FastAPI app."""
return {"message": "Hello from module.py!"}
# {{/docs-fragment helper-function}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/module.py)
### Automatic file discovery
`FastAPIAppEnvironment` automatically discovers and includes the necessary files by analyzing your imports. However, if you have files that aren't automatically detected (like configuration files or data files), you can explicitly include them:
```python
app_env = FastAPIAppEnvironment(
name="fastapi-with-config",
app=app,
include=["app.py", "module.py", "config.yaml"], # Explicit includes
# ...
)
```
## Streamlit multi-script app
### Project structure
```
project/
├── main.py          # Main Streamlit app
├── utils.py         # Utility functions
└── components.py    # Reusable components
### Example: Multi-file Streamlit app
```python
# main.py
import streamlit as st
from utils import process_data
from components import render_chart
st.title("Multi-file Streamlit App")
data = st.file_uploader("Upload data file")
if data:
processed = process_data(data)
render_chart(processed)
```
```python
# utils.py
import pandas as pd
def process_data(data_file):
"""Process uploaded data file."""
df = pd.read_csv(data_file)
# ... processing logic ...
return df
```
```python
# components.py
import streamlit as st
def render_chart(data):
"""Render a chart component."""
st.line_chart(data)
```
### Deploying multi-file Streamlit app
```python
import flyte
import flyte.app
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1",
"pandas==2.2.3",
)
app_env = flyte.app.AppEnvironment(
name="streamlit-multi-file",
image=image,
args="streamlit run main.py --server.port 8080",
port=8080,
include=["main.py", "utils.py", "components.py"], # Include all files
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.deploy(app_env)
print(f"App URL: {app[0].url}")
```
## Complex multi-file example
Here's a more complex example with multiple modules:
### Project structure
```
project/
├── app.py
├── models/
│   ├── __init__.py
│   └── user.py
├── services/
│   ├── __init__.py
│   └── auth.py
└── utils/
    ├── __init__.py
    └── helpers.py
```
### Example code
```python
# app.py
from fastapi import FastAPI
from models.user import User
from services.auth import authenticate
from utils.helpers import format_response
app = FastAPI(title="Complex Multi-file App")
@app.get("/users/{user_id}")
async def get_user(user_id: int):
user = User(id=user_id, name="John Doe")
return format_response(user)
```
```python
# models/user.py
from pydantic import BaseModel
class User(BaseModel):
id: int
name: str
```
```python
# services/auth.py
def authenticate(token: str) -> bool:
# ... authentication logic ...
return True
```
```python
# utils/helpers.py
def format_response(data):
return {"data": data, "status": "success"}
```
### Deploying complex app
```python
from flyte.app.extras import FastAPIAppEnvironment
import flyte
app_env = FastAPIAppEnvironment(
name="complex-app",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
"pydantic",
),
# Include all necessary files
include=[
"app.py",
"models/",
"services/",
"utils/",
],
resources=flyte.Resources(cpu=1, memory="512Mi"),
)
```
## Best practices
1. **Use explicit includes**: For Streamlit apps, explicitly list all files in `include`
2. **Automatic discovery**: For FastAPI apps, `FastAPIAppEnvironment` handles most cases automatically
3. **Organize modules**: Use proper Python package structure with `__init__.py` files
4. **Test locally**: Test your multi-file app locally before deploying
5. **Include all dependencies**: Include all files that your app imports
## Troubleshooting
**Import errors:**
- Verify all files are included in the `include` parameter
- Check that file paths are correct (relative to app definition file)
- Ensure `__init__.py` files are included for packages
**Module not found:**
- Add missing files to the `include` list
- Check that import paths match the file structure
- Verify that the image includes all necessary packages
**File not found at runtime:**
- Ensure all referenced files are included
- Check mount paths for file/directory inputs
- Verify file paths are relative to the app root directory
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/app-usage-patterns ===
# App usage patterns
Apps and tasks can interact in various ways: calling each other via HTTP, webhooks, WebSockets, or direct browser usage. This page describes the different patterns and when to use them.
## Patterns overview
1. **Build apps > App usage patterns > Call app from task**: A task makes HTTP requests to an app
2. **Build apps > App usage patterns > Call task from app (webhooks / APIs)**: An app triggers task execution via the Flyte SDK
3. **Build apps > App usage patterns > Call app from app**: One app makes HTTP requests to another app
4. **Build apps > App usage patterns > WebSocket-based patterns**: Real-time, bidirectional communication
5. **Browser-based access**: Users access apps directly through the browser
## Call app from task
Tasks can call apps by making HTTP requests to the app's endpoint. This is useful when:
- You need to use a long-running service during task execution
- You want to call a model serving endpoint from a batch processing task
- You need to interact with an API from a workflow
### Example: Task calling an app
```
"""Example of a task calling an app."""
import httpx
from fastapi import FastAPI
import flyte
from flyte.app.extras import FastAPIAppEnvironment
app = FastAPI(title="Add One", description="Adds one to the input", version="1.0.0")
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("fastapi", "uvicorn", "httpx")
# {{docs-fragment app-definition}}
app_env = FastAPIAppEnvironment(
name="add-one-app",
app=app,
description="Adds one to the input",
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment app-definition}}
# {{docs-fragment task-env}}
task_env = flyte.TaskEnvironment(
name="add_one_task_env",
image=image,
resources=flyte.Resources(cpu=1, memory="512Mi"),
depends_on=[app_env], # Ensure app is deployed before task runs
)
# {{/docs-fragment task-env}}
# {{docs-fragment app-endpoint}}
@app.get("/")
async def add_one(x: int) -> dict[str, int]:
"""Main endpoint for the add-one app."""
return {"result": x + 1}
# {{/docs-fragment app-endpoint}}
# {{docs-fragment task}}
@task_env.task
async def add_one_task(x: int) -> int:
print(f"Calling app at {app_env.endpoint}")
async with httpx.AsyncClient() as client:
response = await client.get(app_env.endpoint, params={"x": x})
response.raise_for_status()
return response.json()["result"]
# {{/docs-fragment task}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/task_calling_app.py)
Key points:
- The task environment uses `depends_on=[app_env]` to ensure the app is deployed first
- Access the app endpoint via `app_env.endpoint`
- Use standard HTTP client libraries (like `httpx`) to make requests
## Call task from app (webhooks / APIs)
Apps can trigger task execution using the Flyte SDK. This is useful for:
- Webhooks that trigger workflows
- APIs that need to run batch jobs
- Services that need to execute tasks asynchronously
Webhooks are HTTP endpoints that trigger actions in response to external events. Flyte apps can serve as webhook endpoints that trigger task runs, workflows, or other operations.
### Example: Basic webhook app
Here's a simple webhook that triggers Flyte tasks:
```
"""A webhook that triggers Flyte tasks."""
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
from contextlib import asynccontextmanager
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment auth}}
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY", "test-api-key")
security = HTTPBearer()
async def verify_token(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
"""Verify the API key from the bearer token."""
if credentials.credentials != WEBHOOK_API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
return credentials
# {{/docs-fragment auth}}
# {{docs-fragment lifespan}}
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize Flyte before accepting requests."""
await flyte.init_in_cluster.aio()
yield
# Cleanup if needed
# {{/docs-fragment lifespan}}
# {{docs-fragment app}}
app = FastAPI(
title="Flyte Webhook Runner",
description="A webhook service that triggers Flyte task runs",
version="1.0.0",
lifespan=lifespan,
)
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy"}
# {{/docs-fragment app}}
# {{docs-fragment webhook-endpoint}}
@app.post("/run-task/{project}/{domain}/{name}/{version}")
async def run_task(
project: str,
domain: str,
name: str,
version: str,
inputs: dict,
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
"""
Trigger a Flyte task run via webhook.
Returns information about the launched run.
"""
# Fetch the task
task = await remote.TaskDetails.fetch(
project=project,
domain=domain,
name=name,
version=version,
)
# Run the task
run = await flyte.run.aio(task, **inputs)
return {
"url": run.url,
"id": run.id,
"status": "started",
}
# {{/docs-fragment webhook-endpoint}}
# {{docs-fragment env}}
env = FastAPIAppEnvironment(
name="webhook-runner",
app=app,
description="A webhook service that triggers Flyte task runs",
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False, # We handle auth in the app
env_vars={"WEBHOOK_API_KEY": os.getenv("WEBHOOK_API_KEY", "test-api-key")},
)
# {{/docs-fragment env}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/webhook/basic_webhook.py)
Once deployed, you can trigger tasks via HTTP POST:
```bash
curl -X POST "https://your-webhook-url/run-task/flytesnacks/development/my_task/v1" \
-H "Authorization: Bearer test-api-key" \
-H "Content-Type: application/json" \
-d '{"input_key": "input_value"}'
```
Response:
```json
{
"url": "https://console.union.ai/...",
"id": "abc123",
"status": "started"
}
```
### Advanced webhook patterns
**Webhook with validation**
Use Pydantic for input validation:
```python
from pydantic import BaseModel
class TaskInput(BaseModel):
data: dict
priority: int = 0
@app.post("/run-task/{project}/{domain}/{name}/{version}")
async def run_task(
project: str,
domain: str,
name: str,
version: str,
inputs: TaskInput, # Validated input
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
task = await remote.TaskDetails.fetch(
project=project,
domain=domain,
name=name,
version=version,
)
run = await flyte.run.aio(task, **inputs.dict())
return {
"run_id": run.id,
"url": run.url,
}
```
**Webhook with response waiting**
Wait for task completion:
```python
@app.post("/run-task-and-wait/{project}/{domain}/{name}/{version}")
async def run_task_and_wait(
project: str,
domain: str,
name: str,
version: str,
inputs: dict,
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
task = await remote.TaskDetails.fetch(
project=project,
domain=domain,
name=name,
version=version,
)
run = await flyte.run.aio(task, **inputs)
run.wait() # Wait for completion
return {
"run_id": run.id,
"url": run.url,
"status": run.status,
"outputs": run.outputs(),
}
```
**Webhook with secret management**
Use Flyte secrets for API keys:
```python
env = FastAPIAppEnvironment(
name="webhook-runner",
app=app,
secrets=flyte.Secret(key="webhook-api-key", as_env_var="WEBHOOK_API_KEY"),
# ...
)
```
Then access in your app:
```python
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY")
```
### Webhook security and best practices
- **Authentication**: Always secure webhooks with authentication (API keys, tokens, etc.).
- **Input validation**: Validate webhook inputs using Pydantic models.
- **Error handling**: Handle errors gracefully and return meaningful error messages.
- **Async operations**: Use async/await for I/O operations.
- **Health checks**: Include health check endpoints.
- **Logging**: Log webhook requests for debugging and auditing.
- **Rate limiting**: Consider implementing rate limiting for production (see the sketch below).
Security considerations:
- Store API keys in Flyte secrets, not in code.
- Always use HTTPS in production.
- Validate all inputs to prevent injection attacks.
- Implement proper access control mechanisms.
- Log all webhook invocations for security auditing.
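As a starting point for rate limiting, a very rough in-process sketch attached to the webhook `app` defined above (a production deployment would use a shared store such as Redis; `@app.middleware("http")` is standard FastAPI):
```python
import time
from collections import defaultdict

from fastapi import Request
from fastapi.responses import JSONResponse

WINDOW_SECONDS = 60   # sliding window length
MAX_REQUESTS = 30     # allowed requests per client per window
_hits: dict[str, list[float]] = defaultdict(list)

@app.middleware("http")
async def rate_limit(request: Request, call_next):
    now = time.monotonic()
    key = request.client.host if request.client else "unknown"
    # Keep only the hits inside the sliding window.
    _hits[key] = [t for t in _hits[key] if now - t < WINDOW_SECONDS]
    if len(_hits[key]) >= MAX_REQUESTS:
        return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded"})
    _hits[key].append(now)
    return await call_next(request)
```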
### Example: GitHub webhook
Here's an example webhook that triggers tasks based on GitHub events:
```python
from fastapi import FastAPI, Request, Header, HTTPException
import hmac
import hashlib
import os
import flyte
import flyte.remote as remote
app = FastAPI(title="GitHub Webhook Handler")
@app.post("/github-webhook")
async def github_webhook(
request: Request,
x_hub_signature_256: str = Header(None),
):
"""Handle GitHub webhook events."""
body = await request.body()
# Verify signature
secret = os.getenv("GITHUB_WEBHOOK_SECRET")
signature = hmac.new(
secret.encode(),
body,
hashlib.sha256
).hexdigest()
expected_signature = f"sha256={signature}"
if not x_hub_signature_256 or not hmac.compare_digest(x_hub_signature_256, expected_signature):
raise HTTPException(status_code=403, detail="Invalid signature")
# Process webhook
event = await request.json()
event_type = request.headers.get("X-GitHub-Event")
if event_type == "push":
# Trigger deployment task
task = await remote.TaskDetails.fetch(...)
run = await flyte.run.aio(task, commit=event["after"])
return {"run_id": run.id, "url": run.url}
return {"status": "ignored"}
```
## Call app from app
Apps can call other apps by making HTTP requests. This is useful for:
- Microservice architectures
- Proxy/gateway patterns
- A/B testing setups
- Service composition
### Example: App calling another app
```python
import httpx
from fastapi import FastAPI
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# Backend app
app1 = FastAPI(title="Backend API")
env1 = FastAPIAppEnvironment(
name="backend-api",
app=app1,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi", "uvicorn", "httpx"
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
@app1.get("/greeting/{name}")
async def greeting(name: str) -> str:
return f"Hello, {name}!"
# Frontend app that calls the backend
app2 = FastAPI(title="Frontend API")
env2 = FastAPIAppEnvironment(
name="frontend-api",
app=app2,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi", "uvicorn", "httpx"
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
depends_on=[env1], # Ensure backend is deployed first
)
@app2.get("/greeting/{name}")
async def greeting_proxy(name: str):
"""Proxy that calls the backend app."""
async with httpx.AsyncClient() as client:
response = await client.get(f"{env1.endpoint}/greeting/{name}")
response.raise_for_status()
return response.json()
```
Key points:
- Use `depends_on=[env1]` to ensure dependencies are deployed first
- Access the app endpoint via `env1.endpoint`
- Use HTTP clients (like `httpx`) to make requests between apps
### Using AppEndpoint input
You can pass app endpoints as inputs for more flexibility:
```python
env2 = FastAPIAppEnvironment(
name="frontend-api",
app=app2,
inputs=[
flyte.app.Input(
name="backend_url",
value=flyte.app.AppEndpoint(app_name="backend-api"),
env_var="BACKEND_URL",
),
],
# ...
)
@app2.get("/greeting/{name}")
async def greeting_proxy(name: str):
backend_url = os.getenv("BACKEND_URL")
async with httpx.AsyncClient() as client:
response = await client.get(f"{backend_url}/greeting/{name}")
return response.json()
```
## WebSocket-based patterns
WebSockets enable bidirectional, real-time communication between clients and servers. Flyte apps can serve WebSocket endpoints for real-time applications like chat, live updates, or streaming data.
### Example: Basic WebSocket app
Here's a simple FastAPI app with WebSocket support:
```
"""A FastAPI app with WebSocket support."""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import asyncio
import json
from datetime import UTC, datetime
import flyte
from flyte.app.extras import FastAPIAppEnvironment
app = FastAPI(
title="Flyte WebSocket Demo",
description="A FastAPI app with WebSocket support",
version="1.0.0",
)
# {{docs-fragment connection-manager}}
class ConnectionManager:
"""Manages WebSocket connections."""
def __init__(self):
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket):
"""Accept and register a new WebSocket connection."""
await websocket.accept()
self.active_connections.append(websocket)
print(f"Client connected. Total: {len(self.active_connections)}")
def disconnect(self, websocket: WebSocket):
"""Remove a WebSocket connection."""
self.active_connections.remove(websocket)
print(f"Client disconnected. Total: {len(self.active_connections)}")
async def send_personal_message(self, message: str, websocket: WebSocket):
"""Send a message to a specific WebSocket connection."""
await websocket.send_text(message)
async def broadcast(self, message: str):
"""Broadcast a message to all active connections."""
for connection in self.active_connections:
try:
await connection.send_text(message)
except Exception as e:
print(f"Error broadcasting: {e}")
manager = ConnectionManager()
# {{/docs-fragment connection-manager}}
# {{docs-fragment websocket-endpoint}}
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for real-time communication."""
await manager.connect(websocket)
try:
# Send welcome message
await manager.send_personal_message(
json.dumps({
"type": "system",
"message": "Welcome! You are connected.",
"timestamp": datetime.now(UTC).isoformat(),
}),
websocket,
)
# Listen for messages
while True:
data = await websocket.receive_text()
# Echo back to sender
await manager.send_personal_message(
json.dumps({
"type": "echo",
"message": f"Echo: {data}",
"timestamp": datetime.now(UTC).isoformat(),
}),
websocket,
)
# Broadcast to all clients
await manager.broadcast(
json.dumps({
"type": "broadcast",
"message": f"Broadcast: {data}",
"timestamp": datetime.now(UTC).isoformat(),
"connections": len(manager.active_connections),
})
)
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(
json.dumps({
"type": "system",
"message": "A client disconnected",
"connections": len(manager.active_connections),
})
)
# {{/docs-fragment websocket-endpoint}}
# {{docs-fragment env}}
env = FastAPIAppEnvironment(
name="websocket-app",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
"websockets",
),
resources=flyte.Resources(cpu=1, memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment env}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/basic_websocket.py)
### WebSocket patterns
**Echo server**
```python
@app.websocket("/echo")
async def echo(websocket: WebSocket):
await websocket.accept()
try:
while True:
data = await websocket.receive_text()
await websocket.send_text(f"Echo: {data}")
except WebSocketDisconnect:
pass
```
**Broadcast server**
```python
@app.websocket("/broadcast")
async def broadcast(websocket: WebSocket):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
await manager.broadcast(data)
except WebSocketDisconnect:
manager.disconnect(websocket)
```
**Real-time data streaming**
```python
@app.websocket("/stream")
async def stream_data(websocket: WebSocket):
await websocket.accept()
try:
while True:
# Generate or fetch data
data = {"timestamp": datetime.now().isoformat(), "value": random.random()}
await websocket.send_json(data)
await asyncio.sleep(1) # Send update every second
except WebSocketDisconnect:
pass
```
**Chat application**
```python
class ChatRoom:
def __init__(self, name: str):
self.name = name
self.connections: list[WebSocket] = []
async def join(self, websocket: WebSocket):
self.connections.append(websocket)
async def leave(self, websocket: WebSocket):
self.connections.remove(websocket)
async def broadcast(self, message: str, sender: WebSocket):
for connection in self.connections:
if connection != sender:
await connection.send_text(message)
rooms: dict[str, ChatRoom] = {}
@app.websocket("/chat/{room_name}")
async def chat(websocket: WebSocket, room_name: str):
await websocket.accept()
if room_name not in rooms:
rooms[room_name] = ChatRoom(room_name)
room = rooms[room_name]
await room.join(websocket)
try:
while True:
data = await websocket.receive_text()
await room.broadcast(data, websocket)
except WebSocketDisconnect:
await room.leave(websocket)
```
### Using WebSockets with Flyte tasks
You can trigger Flyte tasks from WebSocket messages:
```python
@app.websocket("/task-runner")
async def task_runner(websocket: WebSocket):
await websocket.accept()
try:
while True:
# Receive task request
message = await websocket.receive_text()
request = json.loads(message)
# Trigger Flyte task
task = await remote.TaskDetails.fetch(
project=request["project"],
domain=request["domain"],
name=request["task"],
version=request["version"],
)
run = await flyte.run.aio(task, **request["inputs"])
# Send run info back
await websocket.send_json({
"run_id": run.id,
"url": run.url,
"status": "started",
})
# Optionally stream updates
async for update in run.stream():
await websocket.send_json({
"status": update.status,
"message": update.message,
})
except WebSocketDisconnect:
pass
```
### WebSocket client example
Connect from Python:
```python
import asyncio
import websockets
import json
async def client():
uri = "ws://your-app-url/ws"
async with websockets.connect(uri) as websocket:
# Send message
await websocket.send("Hello, Server!")
# Receive message
response = await websocket.recv()
print(f"Received: {response}")
asyncio.run(client())
```
## Browser-based apps
For browser-based apps (like Streamlit), users interact directly through the web interface. The app URL is accessible in a browser, and users interact with the UI directly - no API calls needed from other services.
To access a browser-based app:
1. Deploy the app
2. Navigate to the app URL in a browser
3. Interact with the UI directly
## Best practices
1. **Use `depends_on`**: Always specify dependencies to ensure proper deployment order.
2. **Handle errors**: Implement proper error handling for HTTP requests.
3. **Use async clients**: Use async HTTP clients (`httpx.AsyncClient`) in async contexts.
4. **Initialize Flyte**: For apps calling tasks, initialize Flyte in the app's startup.
5. **Endpoint access**: Use `app_env.endpoint` or `AppEndpoint` input for accessing app URLs.
6. **Authentication**: Consider authentication when apps call each other (set `requires_auth=True` if needed).
7. **Webhook security**: Secure webhooks with auth, validation, and HTTPS.
8. **WebSocket robustness**: Implement connection management, heartbeats, and rate limiting (see the heartbeat sketch below).
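For the heartbeat piece, a minimal sketch (the interval and message shape are illustrative; `app` is a FastAPI app as in the examples above):
```python
import asyncio
import json

from fastapi import WebSocket, WebSocketDisconnect

HEARTBEAT_SECONDS = 30

async def send_heartbeats(websocket: WebSocket):
    # Periodically ping the client so dead connections surface as send errors.
    while True:
        await asyncio.sleep(HEARTBEAT_SECONDS)
        await websocket.send_text(json.dumps({"type": "ping"}))

@app.websocket("/ws-with-heartbeat")
async def ws_with_heartbeat(websocket: WebSocket):
    await websocket.accept()
    heartbeat_task = asyncio.create_task(send_heartbeats(websocket))
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Echo: {data}")
    except WebSocketDisconnect:
        pass
    finally:
        heartbeat_task.cancel()  # stop pinging once the client disconnects
```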
## Summary
| Pattern | Use Case | Implementation |
|---------|----------|----------------|
| Task → App | Batch processing using inference services | HTTP requests from task |
| App → Task | Webhooks, APIs triggering workflows | Flyte SDK in app |
| App → App | Microservices, proxies, agent routers, LLM routers | HTTP requests between apps |
| Browser → App | User-facing dashboards | Direct browser access |
Choose the pattern that best fits your architecture and requirements.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/secret-based-authentication ===
# Secret-based authentication
In this guide, we'll deploy a FastAPI app that uses API key authentication with Flyte secrets. This allows you to invoke the endpoint from the public internet in a secure manner without exposing API keys in your code.
## Create the secret
Before defining and deploying the app, you need to create the `API_KEY` secret in Flyte. This secret will store your API key securely.
Create the secret using the Flyte CLI:
```bash
flyte create secret API_KEY
```
For example:
```bash
flyte create secret API_KEY my-secret-api-key-12345
```
> [!NOTE]
> The secret name `API_KEY` must match the key specified in the `flyte.Secret()` call in your code. The secret will be available to your app as the environment variable specified in `as_env_var`.
## Define the FastAPI app
Here's a simple FastAPI app that uses `HTTPAuthorizationCredentials` to authenticate requests using a secret stored in Flyte:
```
"""Basic FastAPI authentication using dependency injection."""
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# Get API key from environment variable (loaded from Flyte secret)
# The secret must be created using: flyte create secret API_KEY
API_KEY = os.getenv("API_KEY")
security = HTTPBearer()
async def verify_token(
credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
"""Verify the API key from the bearer token."""
if not API_KEY:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="API_KEY not configured",
)
if credentials.credentials != API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
return credentials
app = FastAPI(title="Authenticated API")
@app.get("/public")
async def public_endpoint():
"""Public endpoint that doesn't require authentication."""
return {"message": "This is public"}
@app.get("/protected")
async def protected_endpoint(
credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
"""Protected endpoint that requires authentication."""
return {
"message": "This is protected",
"user": credentials.credentials,
}
env = FastAPIAppEnvironment(
name="authenticated-api",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False, # We handle auth in the app
secrets=flyte.Secret(key="API_KEY", as_env_var="API_KEY"),
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"API URL: {app_deployment[0].url}")
print(f"Swagger docs: {app_deployment[0].url}/docs")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/basic_auth.py)
As you can see, we:
1. Define a `FastAPI` app
2. Create a `verify_token` function that verifies the API key from the Bearer token
3. Define endpoints that use the `verify_token` function to authenticate requests
4. Configure the `FastAPIAppEnvironment` with:
- `requires_auth=False` - This allows the endpoint to be reached without going through Flyte's authentication, since we're handling authentication ourselves using the `API_KEY` secret
- `secrets=flyte.Secret(key="API_KEY", as_env_var="API_KEY")` - This injects the secret value into the `API_KEY` environment variable at runtime
The key difference from using `env_vars` is that secrets are stored securely in Flyte's secret store and injected at runtime, rather than being passed as plain environment variables.
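To make that difference concrete, a brief hedged sketch (both parameters appear in examples in this guide; the environment names are illustrative, and image/resources configuration is omitted):
```python
from fastapi import FastAPI
import flyte
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI(title="Example App")

# Plain environment variable: the value sits in your code and app spec.
env_plain = FastAPIAppEnvironment(
    name="api-with-env-var",  # illustrative name
    app=app,
    env_vars={"API_KEY": "visible-in-plain-text"},  # not recommended for secrets
)

# Flyte secret: only a reference lives in code; the value is injected at runtime.
env_secure = FastAPIAppEnvironment(
    name="api-with-secret",  # illustrative name
    app=app,
    secrets=flyte.Secret(key="API_KEY", as_env_var="API_KEY"),
)
```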
## Deploy the FastAPI app
Once the secret is created, you can deploy the FastAPI app. Make sure your `config.yaml` file is in the same directory as your script, then run:
```bash
python basic_auth.py
```
Or use the Flyte CLI:
```bash
flyte serve basic_auth.py
```
Deploying the application will stream the status to the console and display the app URL:
```
Deploying Application: authenticated-api
Console URL: https://<union-host>/console/projects/my-project/domains/development/apps/authenticated-api
[Status] Pending: App is pending deployment
[Status] Started: Service is ready
Deployed Endpoint: https://rough-meadow-97cf5.apps.<union-host>
```
## Invoke the endpoint
Once deployed, you can invoke the authenticated endpoint using curl:
```bash
curl -X GET "https://rough-meadow-97cf5.apps./protected" \
-H "Authorization: Bearer "
```
Replace `<your-api-key>` with the actual API key value you used when creating the secret.
For example, if you created the secret with value `my-secret-api-key-12345`:
```bash
curl -X GET "https://rough-meadow-97cf5.apps./protected" \
-H "Authorization: Bearer my-secret-api-key-12345"
```
You should receive a response:
```json
{
"message": "This is protected",
"user": "my-secret-api-key-12345"
}
```
## Authentication for vLLM and SGLang apps
Both vLLM and SGLang apps support API key authentication through their native `--api-key` argument. This allows you to secure your LLM endpoints while keeping them accessible from the public internet.
### Create the authentication secret
Create a secret to store your API key:
```bash
flyte create secret AUTH_SECRET
```
For example:
```bash
flyte create secret AUTH_SECRET my-llm-api-key-12345
```
### Deploy vLLM app with authentication
Here's how to deploy a vLLM app with API key authentication:
```
"""vLLM app with API key authentication."""
import pathlib
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# The secret must be created using: flyte create secret AUTH_SECRET
vllm_app = VLLMAppEnvironment(
name="vllm-app-with-auth",
model_hf_path="Qwen/Qwen3-0.6B", # HuggingFace model path
model_id="qwen3-0.6b", # Model ID exposed by vLLM
resources=flyte.Resources(
cpu="4",
memory="16Gi",
gpu="L40s:1", # GPU required for LLM serving
disk="10Gi",
),
scaling=flyte.app.Scaling(
replicas=(0, 1),
scaledown_after=300, # Scale down after 5 minutes of inactivity
),
# Disable Union's platform-level authentication so you can access the
# endpoint from the public internet
requires_auth=False,
# Inject the secret as an environment variable
secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
# Pass the API key to vLLM's --api-key argument
# The $AUTH_SECRET will be replaced with the actual secret value at runtime
extra_args=[
"--api-key", "$AUTH_SECRET",
],
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.serve(vllm_app)
print(f"Deployed vLLM app: {app.url}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/vllm_with_auth.py)
Key points:
1. **`requires_auth=False`** - Disables Union's platform-level authentication so the endpoint can be accessed from the public internet
2. **`secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET")`** - Injects the secret as an environment variable
3. **`extra_args=["--api-key", "$AUTH_SECRET"]`** - Passes the API key to vLLM's `--api-key` argument. The `$AUTH_SECRET` will be replaced with the actual secret value at runtime
Deploy the app:
```bash
python vllm_with_auth.py
```
Or use the Flyte CLI:
```bash
flyte serve vllm_with_auth.py
```
### Deploy SGLang app with authentication
Here's how to deploy a SGLang app with API key authentication:
```
"""SGLang app with API key authentication."""
import pathlib
from flyteplugins.sglang import SGLangAppEnvironment
import flyte
# The secret must be created using: flyte create secret AUTH_SECRET
sglang_app = SGLangAppEnvironment(
name="sglang-app-with-auth",
model_hf_path="Qwen/Qwen3-0.6B", # HuggingFace model path
model_id="qwen3-0.6b", # Model ID exposed by SGLang
resources=flyte.Resources(
cpu="4",
memory="16Gi",
gpu="L40s:1", # GPU required for LLM serving
disk="10Gi",
),
scaling=flyte.app.Scaling(
replicas=(0, 1),
scaledown_after=300, # Scale down after 5 minutes of inactivity
),
# Disable Union's platform-level authentication so you can access the
# endpoint from the public internet
requires_auth=False,
# Inject the secret as an environment variable
secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
# Pass the API key to SGLang's --api-key argument
# The $AUTH_SECRET will be replaced with the actual secret value at runtime
extra_args=[
"--api-key", "$AUTH_SECRET",
],
)
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.serve(sglang_app)
print(f"Deployed SGLang app: {app.url}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/sglang_with_auth.py)
The configuration is similar to vLLM:
1. **`requires_auth=False`** - Disables Union's platform-level authentication
2. **`secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET")`** - Injects the secret as an environment variable
3. **`extra_args=["--api-key", "$AUTH_SECRET"]`** - Passes the API key to SGLang's `--api-key` argument
Deploy the app:
```bash
python sglang_with_auth.py
```
Or use the Flyte CLI:
```bash
flyte serve sglang_with_auth.py sglang_app
```
### Invoke authenticated LLM endpoints
Once deployed, you can invoke the authenticated endpoints using the OpenAI-compatible API format. Both vLLM and SGLang expose OpenAI-compatible endpoints.
For example, to make a chat completion request:
```bash
curl -X POST "https://your-app-url/v1/chat/completions" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer " \
-d '{
"model": "qwen3-0.6b",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
]
}'
```
Replace `<YOUR_API_KEY>` with the actual API key value you used when creating the secret.
For example, if you created the secret with value `my-llm-api-key-12345`:
```bash
curl -X POST "https://your-app-url/v1/chat/completions" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer my-llm-api-key-12345" \
-d '{
"model": "qwen3-0.6b",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
]
}'
```
You should receive a response with the model's completion.
> [!NOTE]
> The `$AUTH_SECRET` syntax in `extra_args` is automatically replaced with the actual secret value at runtime. This ensures the API key is never exposed in your code or configuration files.
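You can also call the authenticated endpoint programmatically with the OpenAI Python client, passing the same key as `api_key` (the URL and key below are placeholders for your own values):
```python
from openai import OpenAI
client = OpenAI(
    base_url="https://your-app-url/v1",  # your deployed app URL
    api_key="my-llm-api-key-12345",      # the value stored in AUTH_SECRET
)
response = client.chat.completions.create(
    model="qwen3-0.6b",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```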
## Accessing Swagger documentation
The app also includes a public health check endpoint and Swagger UI documentation:
- **Health check**: `https://your-app-url/health`
- **Swagger UI**: `https://your-app-url/docs`
- **ReDoc**: `https://your-app-url/redoc`
The Swagger UI will show an "Authorize" button where you can enter your Bearer token to test authenticated endpoints directly from the browser.
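For example, you can confirm the app is up without supplying an API key, since the health endpoint is public:
```bash
curl https://your-app-url/health
```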
## Security best practices
1. **Use strong API keys**: Generate cryptographically secure random strings for your API keys (see the example below this list)
2. **Rotate keys regularly**: Periodically rotate your API keys for better security
3. **Scope secrets appropriately**: Use project/domain scoping when creating secrets if you want to limit access:
```bash
flyte create secret --project my-project --domain development API_KEY my-secret-value
```
4. **Never commit secrets**: Always use Flyte secrets for API keys, never hardcode them in your code
5. **Use HTTPS**: Always use HTTPS in production (Flyte apps are served over HTTPS by default)
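For example (item 1), you can generate a suitable key locally using Python's standard library:
```python
import secrets
# Generate a URL-safe random string to use as the AUTH_SECRET value
print(secrets.token_urlsafe(32))
```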
## Troubleshooting
**Authentication failing:**
- Verify the secret exists: `flyte get secret API_KEY`
- Check that the secret key name matches exactly (case-sensitive)
- Ensure you're using the correct Bearer token value
- Verify the `as_env_var` parameter matches the environment variable name in your code
**Secret not found:**
- Make sure you've created the secret before deploying the app
- Check the secret scope (organization vs project/domain) matches your app's project/domain
- Verify the secret name matches exactly (should be `API_KEY`)
**App not starting:**
- Check container logs for errors
- Verify all dependencies are installed in the image
- Ensure the secret is accessible in the app's project/domain
**LLM app authentication not working:**
- Verify the secret exists: `flyte get secret AUTH_SECRET`
- Check that `$AUTH_SECRET` is correctly specified in `extra_args` (note the `$` prefix)
- Ensure the secret name matches exactly (case-sensitive) in both the `flyte.Secret()` call and `extra_args`
- Verify that the `--api-key` argument is correctly passed in `extra_args` (this applies to both vLLM and SGLang)
- Check that `requires_auth=False` is set to allow public access
## Next steps
- Learn more about **Configure tasks > Secrets** in Flyte
- See **Build apps > Secret-based authentication > app usage patterns** for webhook examples and authentication patterns
- Learn about **Build apps > vLLM app** and **Build apps > SGLang app** for serving LLMs
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/streamlit-app ===
# Streamlit app
Streamlit is a popular framework for building interactive web applications and dashboards. Flyte makes it easy to deploy Streamlit apps as long-running services.
## Basic Streamlit app
The simplest way to deploy a Streamlit app is to use the built-in Streamlit "hello" demo:
```
"""A basic Streamlit app using the built-in hello demo."""
import flyte
import flyte.app
# {{docs-fragment image}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1")
# {{/docs-fragment image}}
# {{docs-fragment app-env}}
app_env = flyte.app.AppEnvironment(
name="streamlit-hello",
image=image,
command="streamlit hello --server.port 8080",
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.deploy(app_env)
print(f"App URL: {app[0].url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/basic_streamlit.py)
## Custom Streamlit app
For a custom Streamlit app, use the `include` parameter to bundle your app files:
```
"""A custom Streamlit app with multiple files."""
import pathlib
import flyte
import flyte.app
# {{docs-fragment image}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1",
"pandas==2.2.3",
"numpy==2.2.3",
)
# {{/docs-fragment image}}
# {{docs-fragment app-env}}
app_env = flyte.app.AppEnvironment(
name="streamlit-custom-app",
image=image,
args="streamlit run main.py --server.port 8080",
port=8080,
include=["main.py", "utils.py"], # Include your app files
resources=flyte.Resources(cpu="1", memory="1Gi"),
requires_auth=False,
)
# {{/docs-fragment app-env}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app = flyte.deploy(app_env)
print(f"App URL: {app[0].url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/custom_streamlit.py)
Your `main.py` file would contain your Streamlit app code:
```
import os
import streamlit as st
from utils import generate_data
# {{docs-fragment streamlit-app}}
all_columns = ["Apples", "Orange", "Pineapple"]
with st.container(border=True):
columns = st.multiselect("Columns", all_columns, default=all_columns)
all_data = st.cache_data(generate_data)(columns=all_columns, seed=101)
data = all_data[columns]
tab1, tab2 = st.tabs(["Chart", "Dataframe"])
tab1.line_chart(data, height=250)
tab2.dataframe(data, height=250, use_container_width=True)
st.write(f"Environment: {os.environ}")
# {{/docs-fragment streamlit-app}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/main.py)
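The `generate_data` helper imported from `utils.py` is not shown above; see the linked repository for the actual implementation. A hypothetical sketch compatible with the code above might look like this:
```python
# utils.py (hypothetical sketch; see the linked repo for the real helper)
import numpy as np
import pandas as pd
def generate_data(columns: list[str], seed: int = 101) -> pd.DataFrame:
    """Generate one random-walk series per column for the demo chart."""
    rng = np.random.default_rng(seed)
    values = rng.standard_normal((50, len(columns))).cumsum(axis=0)
    return pd.DataFrame(values, columns=columns)
```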
## Multi-file Streamlit app
For apps with multiple files, include all necessary files:
```python
app_env = flyte.app.AppEnvironment(
name="streamlit-multi-file",
image=image,
args="streamlit run main.py --server.port 8080",
port=8080,
include=["main.py", "utils.py", "components.py"], # Include all files
resources=flyte.Resources(cpu="1", memory="1Gi"),
)
```
Structure your project like this:
```
project/
βββ main.py # Main Streamlit app
βββ utils.py # Utility functions
βββ components.py # Reusable components
```
## Example: Data visualization dashboard
Here's a complete example of a Streamlit dashboard:
```python
# main.py
import streamlit as st
import pandas as pd
import numpy as np
st.title("Sales Dashboard")
# Load data
@st.cache_data
def load_data():
return pd.DataFrame({
"date": pd.date_range("2024-01-01", periods=100, freq="D"),
"sales": np.random.randint(1000, 5000, 100),
})
data = load_data()
# Sidebar filters
st.sidebar.header("Filters")
start_date = st.sidebar.date_input("Start date", value=data["date"].min())
end_date = st.sidebar.date_input("End date", value=data["date"].max())
# Filter data
filtered_data = data[
(data["date"] >= pd.Timestamp(start_date)) &
(data["date"] <= pd.Timestamp(end_date))
]
# Display metrics
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Sales", f"${filtered_data['sales'].sum():,.0f}")
with col2:
st.metric("Average Sales", f"${filtered_data['sales'].mean():,.0f}")
with col3:
st.metric("Days", len(filtered_data))
# Chart
st.line_chart(filtered_data.set_index("date")["sales"])
```
Deploy with:
```python
import flyte
import flyte.app
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"streamlit==1.41.1",
"pandas==2.2.3",
"numpy==2.2.3",
)
app_env = flyte.app.AppEnvironment(
name="sales-dashboard",
image=image,
args="streamlit run main.py --server.port 8080",
port=8080,
include=["main.py"],
resources=flyte.Resources(cpu="2", memory="2Gi"),
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.deploy(app_env)
print(f"Dashboard URL: {app[0].url}")
```
## Custom domain
You can use a custom subdomain for your Streamlit app:
```python
app_env = flyte.app.AppEnvironment(
name="streamlit-app",
image=image,
command="streamlit hello --server.port 8080",
port=8080,
domain=flyte.app.Domain(subdomain="dashboard"), # Custom subdomain
resources=flyte.Resources(cpu="1", memory="1Gi"),
)
```
## Best practices
1. **Use `include` for custom apps**: Always include your app files when deploying custom Streamlit code
2. **Set the port correctly**: Ensure your Streamlit app uses `--server.port 8080` (or match your `port` setting)
3. **Cache data**: Use `@st.cache_data` for expensive computations to improve performance
4. **Resource sizing**: Adjust resources based on your app's needs (data size, computations)
5. **Public vs private**: Set `requires_auth=False` for public dashboards, `True` for internal tools
## Troubleshooting
**App not loading:**
- Verify the port matches (use `--server.port 8080`)
- Check that all required files are included
- Review container logs for errors
**Missing dependencies:**
- Ensure all required packages are in your image's pip packages
- Check that file paths in `include` are correct
**Performance issues:**
- Increase CPU/memory resources
- Use Streamlit's caching features (`@st.cache_data`, `@st.cache_resource`)
- Optimize data processing
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/fastapi-app ===
# FastAPI app
FastAPI is a modern, fast web framework for building APIs. Flyte provides `FastAPIAppEnvironment` which makes it easy to deploy FastAPI applications.
## Basic FastAPI app
Here's a simple FastAPI app:
```
"""A basic FastAPI app example."""
from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment fastapi-app}}
app = FastAPI(
title="My API",
description="A simple FastAPI application",
version="1.0.0",
)
# {{/docs-fragment fastapi-app}}
# {{docs-fragment fastapi-env}}
env = FastAPIAppEnvironment(
name="my-fastapi-app",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
# {{/docs-fragment fastapi-env}}
# {{docs-fragment endpoints}}
@app.get("/")
async def root():
return {"message": "Hello, World!"}
@app.get("/health")
async def health_check():
return {"status": "healthy"}
# {{/docs-fragment endpoints}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(env)
print(f"Deployed: {app_deployment[0].url}")
print(f"API docs: {app_deployment[0].url}/docs")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/basic_fastapi.py)
Once deployed, you can:
- Access the API at the generated URL
- View interactive API docs at `/docs` (Swagger UI)
- View alternative docs at `/redoc`
## Serving a machine learning model
Here's an example of serving a scikit-learn model:
```python
import os
from contextlib import asynccontextmanager
import joblib
import flyte
from fastapi import FastAPI
from flyte.app.extras import FastAPIAppEnvironment
from pydantic import BaseModel
# Define request/response models
class PredictionRequest(BaseModel):
    feature1: float
    feature2: float
    feature3: float
class PredictionResponse(BaseModel):
    prediction: float
    probability: float
# Model handle, populated at startup (you would typically load this from storage)
model = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    model_path = os.getenv("MODEL_PATH", "/app/models/model.joblib")
    # In production, load from your storage
    with open(model_path, "rb") as f:
        model = joblib.load(f)
    yield
# Register the lifespan handler so the model is loaded at startup
app = FastAPI(title="ML Model API", lifespan=lifespan)
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
# Make prediction
# prediction = model.predict([[request.feature1, request.feature2, request.feature3]])
# Dummy prediction for demo
prediction = 0.85
probability = 0.92
return PredictionResponse(
prediction=prediction,
probability=probability,
)
env = FastAPIAppEnvironment(
name="ml-model-api",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
"scikit-learn",
"pydantic",
"joblib",
),
inputs=[
flyte.app.Input(
name="model_file",
value=flyte.io.File("s3://bucket/models/model.joblib"),
mount="/app/models",
env_var="MODEL_PATH",
),
    ],
resources=flyte.Resources(cpu=2, memory="2Gi"),
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config()
app_deployment = flyte.deploy(env)
print(f"API URL: {app_deployment[0].url}")
print(f"Swagger docs: {app_deployment[0].url}/docs")
```
## Accessing Swagger documentation
FastAPI automatically generates interactive API documentation. Once deployed:
- **Swagger UI**: Access at `{app_url}/docs`
- **ReDoc**: Access at `{app_url}/redoc`
- **OpenAPI JSON**: Access at `{app_url}/openapi.json`
The Swagger UI provides an interactive interface where you can:
- See all available endpoints
- Test API calls directly from the browser
- View request/response schemas
- See example payloads
## Example: REST API with multiple endpoints
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import flyte
from flyte.app.extras import FastAPIAppEnvironment
app = FastAPI(title="Product API")
# Data models
class Product(BaseModel):
id: int
name: str
price: float
class ProductCreate(BaseModel):
name: str
price: float
# In-memory database (use real database in production)
products_db = []
@app.get("/products", response_model=List[Product])
async def get_products():
return products_db
@app.get("/products/{product_id}", response_model=Product)
async def get_product(product_id: int):
product = next((p for p in products_db if p["id"] == product_id), None)
if not product:
raise HTTPException(status_code=404, detail="Product not found")
return product
@app.post("/products", response_model=Product)
async def create_product(product: ProductCreate):
new_product = {
"id": len(products_db) + 1,
"name": product.name,
"price": product.price,
}
products_db.append(new_product)
return new_product
env = FastAPIAppEnvironment(
name="product-api",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config()
app_deployment = flyte.deploy(env)
print(f"API URL: {app_deployment[0].url}")
print(f"Swagger docs: {app_deployment[0].url}/docs")
```
## Multi-file FastAPI app
Here's an example of a multi-file FastAPI app:
```
"""Multi-file FastAPI app example."""
from fastapi import FastAPI
from module import function # Import from another file
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment
# {{docs-fragment app-definition}}
app = FastAPI(title="Multi-file FastAPI Demo")
app_env = FastAPIAppEnvironment(
name="fastapi-multi-file",
app=app,
image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"fastapi",
"uvicorn",
),
resources=flyte.Resources(cpu=1, memory="512Mi"),
requires_auth=False,
# FastAPIAppEnvironment automatically includes necessary files
# But you can also specify explicitly:
# include=["app.py", "module.py"],
)
# {{/docs-fragment app-definition}}
# {{docs-fragment endpoint}}
@app.get("/")
async def root():
return function() # Uses function from module.py
# {{/docs-fragment endpoint}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
app_deployment = flyte.deploy(app_env)
print(f"Deployed: {app_deployment[0].url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/app.py)
The helper module:
```
# {{docs-fragment helper-function}}
def function():
"""Helper function used by the FastAPI app."""
return {"message": "Hello from module.py!"}
# {{/docs-fragment helper-function}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/module.py)
See **Build apps > Multi-script apps** for more details on building FastAPI apps with multiple files.
## Best practices
1. **Use Pydantic models**: Define request/response models for type safety and automatic validation
2. **Handle errors**: Use HTTPException for proper error responses
3. **Async operations**: Use async/await for I/O operations
4. **Environment variables**: Use environment variables for configuration
5. **Logging**: Add proper logging for debugging and monitoring
6. **Health checks**: Always include a `/health` endpoint
7. **API documentation**: FastAPI auto-generates docs, but add descriptions to your endpoints
## Advanced features
FastAPI supports many features that work with Flyte:
- **Dependencies**: Use FastAPI's dependency injection system
- **Background tasks**: Run background tasks with BackgroundTasks
- **WebSockets**: See **Build apps > App usage patterns > WebSocket-based patterns** for details
- **Authentication**: Add authentication middleware (see **Build apps > Secret-based authentication**)
- **CORS**: Configure CORS for cross-origin requests (see the sketch after this list)
- **Rate limiting**: Add rate limiting middleware
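For example, CORS can be enabled with FastAPI's standard middleware on the same `app` object used in the examples above (the allowed origin is a placeholder):
```python
from fastapi.middleware.cors import CORSMiddleware
# Allow a specific browser origin to call this API cross-origin
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://example.com"],  # placeholder: your allowed origins
    allow_methods=["*"],
    allow_headers=["*"],
)
```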
## Troubleshooting
**App not starting:**
- Check that uvicorn can find your app module
- Verify all dependencies are installed in the image
- Check container logs for startup errors
**Import errors:**
- Ensure all imported modules are available
- Use `include` parameter if you have custom modules
- Check that file paths are correct
**API not accessible:**
- Verify `requires_auth` setting
- Check that the app is listening on the correct port (8080)
- Review network/firewall settings
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/vllm-app ===
# vLLM app
vLLM is a high-performance library for serving large language models (LLMs). Flyte provides `VLLMAppEnvironment` for deploying vLLM model servers.
## Installation
First, install the vLLM plugin:
```bash
pip install --pre flyteplugins-vllm
```
## Basic vLLM app
Here's a simple example serving a HuggingFace model:
```
"""A simple vLLM app example."""
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# {{docs-fragment basic-vllm-app}}
vllm_app = VLLMAppEnvironment(
name="my-llm-app",
model_hf_path="Qwen/Qwen3-0.6B", # HuggingFace model path
model_id="qwen3-0.6b", # Model ID exposed by vLLM
resources=flyte.Resources(
cpu="4",
memory="16Gi",
gpu="L40s:1", # GPU required for LLM serving
disk="10Gi",
),
scaling=flyte.app.Scaling(
replicas=(0, 1),
scaledown_after=300, # Scale down after 5 minutes of inactivity
),
requires_auth=False,
)
# {{/docs-fragment basic-vllm-app}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.serve(vllm_app)
print(f"Deployed vLLM app: {app.url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/basic_vllm.py)
## Using prefetched models
You can use models prefetched with `flyte.prefetch`:
```
"""vLLM app using prefetched models."""
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# {{docs-fragment prefetch}}
# Use the prefetched model
vllm_app = VLLMAppEnvironment(
name="my-llm-app",
model_hf_path="Qwen/Qwen3-0.6B", # this is a placeholder
model_id="qwen3-0.6b",
resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1", disk="10Gi"),
stream_model=True, # Stream model directly from blob store to GPU
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config()
# Prefetch the model first
run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")
run.wait()
# Use the prefetched model
app = flyte.serve(
vllm_app.clone_with(
vllm_app.name,
model_hf_path=None,
model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
)
)
print(f"Deployed vLLM app: {app.url}")
# {{/docs-fragment prefetch}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/vllm_with_prefetch.py)
## Model streaming
`VLLMAppEnvironment` supports streaming models directly from blob storage to GPU memory, reducing startup time.
When `stream_model=True` and the `model_path` argument is provided with either a `flyte.io.Dir` or a `RunOutput` pointing to a path in the object store:
- Model weights stream directly from storage to GPU
- Faster startup time (no full download required)
- Lower disk space requirements
> [!NOTE]
> The contents of the model directory must be compatible with the vLLM-supported formats, e.g. the HuggingFace model
> serialization format.
## Custom vLLM arguments
Use `extra_args` to pass additional arguments to vLLM:
```python
vllm_app = VLLMAppEnvironment(
name="custom-vllm-app",
model_hf_path="Qwen/Qwen3-0.6B",
model_id="qwen3-0.6b",
extra_args=[
"--max-model-len", "8192", # Maximum context length
"--gpu-memory-utilization", "0.8", # GPU memory utilization
"--trust-remote-code", # Trust remote code in models
],
resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
# ...
)
```
See the [vLLM documentation](https://docs.vllm.ai/en/stable/configuration/engine_args.html) for all available arguments.
## Using the OpenAI-compatible API
Once deployed, your vLLM app exposes an OpenAI-compatible API:
```python
from openai import OpenAI
client = OpenAI(
base_url="https://your-app-url/v1", # vLLM endpoint
api_key="your-api-key", # If you passed an --api-key argument
)
response = client.chat.completions.create(
model="qwen3-0.6b", # Your model_id
messages=[
{"role": "user", "content": "Hello, how are you?"}
],
)
print(response.choices[0].message.content)
```
> [!TIP]
> If you passed an `--api-key` argument, you can use the `api_key` parameter to authenticate your requests.
> See **Build apps > Secret-based authentication > Authentication for vLLM and SGLang apps > Deploy vLLM app with authentication** for more details on how to pass auth secrets to your app.
## Multi-GPU inference (Tensor Parallelism)
For larger models, use multiple GPUs with tensor parallelism:
```
"""vLLM app with multi-GPU tensor parallelism."""
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# {{docs-fragment multi-gpu}}
vllm_app = VLLMAppEnvironment(
name="multi-gpu-llm-app",
model_hf_path="meta-llama/Llama-2-70b-hf",
model_id="llama-2-70b",
resources=flyte.Resources(
cpu="8",
memory="32Gi",
gpu="L40s:4", # 4 GPUs for tensor parallelism
disk="100Gi",
),
extra_args=[
"--tensor-parallel-size", "4", # Use 4 GPUs
"--max-model-len", "4096",
"--gpu-memory-utilization", "0.9",
],
requires_auth=False,
)
# {{/docs-fragment multi-gpu}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/vllm_multi_gpu.py)
The `--tensor-parallel-size` value should match the number of GPUs specified in `resources`.
## Model sharding with prefetch
You can prefetch and shard models for multi-GPU inference:
```python
# Prefetch with sharding configuration
run = flyte.prefetch.hf_model(
repo="meta-llama/Llama-2-70b-hf",
accelerator="L40s:4",
shard_config=flyte.prefetch.ShardConfig(
engine="vllm",
args=flyte.prefetch.VLLMShardArgs(
tensor_parallel_size=4,
dtype="auto",
trust_remote_code=True,
),
),
)
run.wait()
# Use the sharded model
vllm_app = VLLMAppEnvironment(
name="sharded-llm-app",
model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
model_id="llama-2-70b",
resources=flyte.Resources(cpu="8", memory="32Gi", gpu="L40s:4", disk="100Gi"),
extra_args=["--tensor-parallel-size", "4"],
stream_model=True,
)
```
See **Serve and deploy apps > Prefetching models** for more details on sharding.
## Autoscaling
vLLM apps work well with autoscaling:
```python
vllm_app = VLLMAppEnvironment(
name="autoscaling-llm-app",
model_hf_path="Qwen/Qwen3-0.6B",
model_id="qwen3-0.6b",
resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
scaling=flyte.app.Scaling(
replicas=(0, 1), # Scale to zero when idle
scaledown_after=600, # 10 minutes idle before scaling down
),
# ...
)
```
## Best practices
1. **Use prefetching**: Prefetch models for faster deployment and better reproducibility
2. **Enable streaming**: Use `stream_model=True` to reduce startup time and disk usage
3. **Right-size GPUs**: Match GPU memory to model size
4. **Configure memory utilization**: Use `--gpu-memory-utilization` to control memory usage
5. **Use tensor parallelism**: For large models, use multiple GPUs with `tensor-parallel-size`
6. **Set autoscaling**: Use appropriate idle TTL to balance cost and performance
7. **Limit context length**: Use `--max-model-len` for smaller models to reduce memory usage
## Troubleshooting
**Model loading fails:**
- Verify GPU memory is sufficient for the model
- Check that the model path or HuggingFace path is correct
- Review container logs for detailed error messages
**Out of memory errors:**
- Reduce `--max-model-len`
- Lower `--gpu-memory-utilization`
- Use a smaller model or more GPUs
**Slow startup:**
- Enable `stream_model=True` for faster loading
- Prefetch models before deployment
- Use faster storage backends
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/build-apps/sglang-app ===
# SGLang app
SGLang is a fast structured generation library for large language models (LLMs). Flyte provides `SGLangAppEnvironment` for deploying SGLang model servers.
## Installation
First, install the SGLang plugin:
```bash
pip install --pre flyteplugins-sglang
```
## Basic SGLang app
Here's a simple example serving a HuggingFace model:
```
"""A simple SGLang app example."""
from flyteplugins.sglang import SGLangAppEnvironment
import flyte
# {{docs-fragment basic-sglang-app}}
sglang_app = SGLangAppEnvironment(
name="my-sglang-app",
model_hf_path="Qwen/Qwen3-0.6B", # HuggingFace model path
model_id="qwen3-0.6b", # Model ID exposed by SGLang
resources=flyte.Resources(
cpu="4",
memory="16Gi",
gpu="L40s:1", # GPU required for LLM serving
disk="10Gi",
),
scaling=flyte.app.Scaling(
replicas=(0, 1),
scaledown_after=300, # Scale down after 5 minutes of inactivity
),
requires_auth=False,
)
# {{/docs-fragment basic-sglang-app}}
# {{docs-fragment deploy}}
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.serve(sglang_app)
print(f"Deployed SGLang app: {app.url}")
# {{/docs-fragment deploy}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/basic_sglang.py)
## Using prefetched models
You can use models prefetched with `flyte.prefetch`:
```
"""SGLang app using prefetched models."""
from flyteplugins.sglang import SGLangAppEnvironment
import flyte
# {{docs-fragment prefetch}}
# Use the prefetched model
sglang_app = SGLangAppEnvironment(
name="my-sglang-app",
model_hf_path="Qwen/Qwen3-0.6B", # this is a placeholder
model_id="qwen3-0.6b",
resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1", disk="10Gi"),
stream_model=True, # Stream model directly from blob store to GPU
requires_auth=False,
)
if __name__ == "__main__":
flyte.init_from_config()
# Prefetch the model first
run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")
run.wait()
app = flyte.serve(
sglang_app.clone_with(
sglang_app.name,
model_hf_path=None,
model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
)
)
print(f"Deployed SGLang app: {app.url}")
# {{/docs-fragment prefetch}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/sglang_with_prefetch.py)
## Model streaming
`SGLangAppEnvironment` supports streaming models directly from blob storage to GPU memory, reducing startup time.
When `stream_model=True` and the `model_path` argument is provided with either a `flyte.io.Dir` or a `RunOutput` pointing to a path in the object store:
- Model weights stream directly from storage to GPU
- Faster startup time (no full download required)
- Lower disk space requirements
> [!NOTE]
> The contents of the model directory must be compatible with the SGLang-supported formats, e.g. the HuggingFace model
> serialization format.
## Custom SGLang arguments
Use `extra_args` to pass additional arguments to SGLang:
```python
sglang_app = SGLangAppEnvironment(
name="custom-sglang-app",
model_hf_path="Qwen/Qwen3-0.6B",
model_id="qwen3-0.6b",
extra_args=[
"--max-model-len", "8192", # Maximum context length
"--mem-fraction-static", "0.8", # Memory fraction for static allocation
"--trust-remote-code", # Trust remote code in models
],
resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
# ...
)
```
See the [SGLang server arguments documentation](https://docs.sglang.io/advanced_features/server_arguments.html) for all available options.
## Using the OpenAI-compatible API
Once deployed, your SGLang app exposes an OpenAI-compatible API:
```python
from openai import OpenAI
client = OpenAI(
base_url="https://your-app-url/v1", # SGLang endpoint
api_key="your-api-key", # If you passed an --api-key argument
)
response = client.chat.completions.create(
model="qwen3-0.6b", # Your model_id
messages=[
{"role": "user", "content": "Hello, how are you?"}
],
)
print(response.choices[0].message.content)
```
> [!TIP]
> If you passed an `--api-key` argument, you can use the `api_key` parameter to authenticate your requests.
> See **Build apps > Secret-based authentication > Deploy SGLang app with authentication** for more details on how to pass auth secrets to your app.
## Multi-GPU inference (Tensor Parallelism)
For larger models, use multiple GPUs with tensor parallelism:
```python
sglang_app = SGLangAppEnvironment(
name="multi-gpu-sglang-app",
model_hf_path="meta-llama/Llama-2-70b-hf",
model_id="llama-2-70b",
resources=flyte.Resources(
cpu="8",
memory="32Gi",
gpu="L40s:4", # 4 GPUs for tensor parallelism
disk="100Gi",
),
extra_args=[
"--tp", "4", # Tensor parallelism size (4 GPUs)
"--max-model-len", "4096",
"--mem-fraction-static", "0.9",
],
requires_auth=False,
)
```
The tensor parallelism size (`--tp`) should match the number of GPUs specified in `resources`.
## Model sharding with prefetch
You can prefetch and shard models for multi-GPU inference with SGLang (sharding itself currently uses the vLLM engine, but the resulting shards can be served with SGLang):
```python
# Prefetch with sharding configuration
run = flyte.prefetch.hf_model(
repo="meta-llama/Llama-2-70b-hf",
accelerator="L40s:4",
shard_config=flyte.prefetch.ShardConfig(
engine="vllm",
args=flyte.prefetch.VLLMShardArgs(
tensor_parallel_size=4,
dtype="auto",
trust_remote_code=True,
),
),
)
run.wait()
# Use the sharded model
sglang_app = SGLangAppEnvironment(
name="sharded-sglang-app",
model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
model_id="llama-2-70b",
resources=flyte.Resources(cpu="8", memory="32Gi", gpu="L40s:4", disk="100Gi"),
extra_args=["--tp", "4"],
stream_model=True,
)
```
See **Serve and deploy apps > Prefetching models** for more details on sharding.
## Autoscaling
SGLang apps work well with autoscaling:
```python
sglang_app = SGLangAppEnvironment(
name="autoscaling-sglang-app",
model_hf_path="Qwen/Qwen3-0.6B",
model_id="qwen3-0.6b",
resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
scaling=flyte.app.Scaling(
replicas=(0, 1), # Scale to zero when idle
scaledown_after=600, # 10 minutes idle before scaling down
),
# ...
)
```
## Structured generation
SGLang is particularly well-suited for structured generation tasks. The deployed app supports standard OpenAI API calls, and you can use SGLang's advanced features through the API.
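For example, if your deployed server supports OpenAI-style JSON mode (support depends on the SGLang version and server flags, so treat this as a sketch), you can request structured output through the same API:
```python
from openai import OpenAI
client = OpenAI(base_url="https://your-app-url/v1", api_key="your-api-key")
response = client.chat.completions.create(
    model="qwen3-0.6b",
    messages=[{"role": "user", "content": "List three fruits as a JSON object."}],
    response_format={"type": "json_object"},  # assumes JSON mode is available
)
print(response.choices[0].message.content)
```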
## Best practices
1. **Use prefetching**: Prefetch models for faster deployment and better reproducibility
2. **Enable streaming**: Use `stream_model=True` to reduce startup time and disk usage
3. **Right-size GPUs**: Match GPU memory to model size
4. **Use tensor parallelism**: For large models, use multiple GPUs with `--tp`
5. **Set autoscaling**: Use appropriate idle TTL to balance cost and performance
6. **Configure memory**: Use `--mem-fraction-static` to control memory allocation
7. **Limit context length**: Use `--max-model-len` for smaller models to reduce memory usage
## Troubleshooting
**Model loading fails:**
- Verify GPU memory is sufficient for the model
- Check that the model path or HuggingFace path is correct
- Review container logs for detailed error messages
**Out of memory errors:**
- Reduce `--max-model-len`
- Lower `--mem-fraction-static`
- Use a smaller model or more GPUs
**Slow startup:**
- Enable `stream_model=True` for faster loading
- Prefetch models before deployment
- Use faster storage backends
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps ===
# Serve and deploy apps
Flyte provides two main ways to deploy apps: **serve** (for development) and **deploy** (for production). This section covers both methods and their differences.
## Serve vs Deploy
### `flyte serve`
Serving is designed for development and iteration:
- **Dynamic input modification**: You can override app inputs when serving
- **Quick iteration**: Faster feedback loop for development
- **Interactive**: Better suited for testing and experimentation
### `flyte deploy`
Deployment is designed for production use:
- **Immutable**: Apps are deployed with fixed configurations
- **Production-ready**: Optimized for stability and reproducibility
## Using Python SDK
### Serve
```python
import flyte
import flyte.app
app_env = flyte.app.AppEnvironment(
name="my-app",
    image=flyte.Image.from_debian_base().with_pip_packages("streamlit==1.41.1"),
args=["streamlit", "hello", "--server.port", "8080"],
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
)
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.serve(app_env)
print(f"Served at: {app.url}")
```
### Deploy
```python
import flyte
import flyte.app
app_env = flyte.app.AppEnvironment(
name="my-app",
    image=flyte.Image.from_debian_base().with_pip_packages("streamlit==1.41.1"),
args=["streamlit", "hello", "--server.port", "8080"],
port=8080,
resources=flyte.Resources(cpu="1", memory="1Gi"),
)
if __name__ == "__main__":
flyte.init_from_config()
deployments = flyte.deploy(app_env)
# Access deployed app URL from the deployment
for deployed_env in deployments[0].envs.values():
print(f"Deployed: {deployed_env.deployed_app.url}")
```
## Using the CLI
### Serve
```bash
flyte serve path/to/app.py app_env
```
### Deploy
```bash
flyte deploy path/to/app.py app_env
```
## Next steps
- **Serve and deploy apps > How app serving works**: Understanding the serve process and configuration options
- **Serve and deploy apps > How app deployment works**: Understanding the deploy process and configuration options
- **Serve and deploy apps > Activating and deactivating apps**: Managing app lifecycle
- **Serve and deploy apps > Prefetching models**: Download and shard HuggingFace models for vLLM and SGLang
## Subpages
- **Serve and deploy apps > How app serving works**
- **Serve and deploy apps > How app deployment works**
- **Serve and deploy apps > Activating and deactivating apps**
- **Serve and deploy apps > Prefetching models**
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/how-app-serving-works ===
# How app serving works
Serving is the recommended way to deploy apps during development. It provides a faster feedback loop and allows you to dynamically modify inputs.
## Overview
When you serve an app, the following happens:
1. **Code bundling**: Your app code is bundled and prepared
2. **Image building**: Container images are built (if needed)
3. **Deployment**: The app is deployed to your Flyte cluster
4. **Activation**: The app is automatically activated and ready to use
5. **URL generation**: A URL is generated for accessing the app
## Using the Python SDK
The simplest way to serve an app:
```python
import flyte
import flyte.app
app_env = flyte.app.AppEnvironment(
name="my-dev-app",
inputs=[flyte.app.Input(name="model_path", value="s3://bucket/models/model.pkl")],
# ...
)
if __name__ == "__main__":
flyte.init_from_config()
app = flyte.serve(app_env)
print(f"App served at: {app.url}")
```
## Overriding inputs
One key advantage of serving is the ability to override inputs dynamically:
```python
app = flyte.with_servecontext(
input_values={
"my-dev-app": {
"model_path": "s3://bucket/models/test-model.pkl",
}
}
).serve(app_env)
```
This is useful for:
- Testing different configurations
- Using different models or data sources
- A/B testing during development
## Advanced serving options
Use `with_servecontext()` for more control over the serving process:
```python
import logging
import flyte
app = flyte.with_servecontext(
version="v1.0.0",
project="my-project",
domain="development",
env_vars={"LOG_LEVEL": "DEBUG"},
input_values={"app-name": {"input": "value"}},
cluster_pool="dev-pool",
log_level=logging.INFO,
log_format="json",
dry_run=False,
).serve(app_env)
```
## Using CLI
You can also serve apps from the command line:
```bash
flyte serve path/to/app.py app
```
Where `app` is the variable name of the `AppEnvironment` object.
## Return value
`flyte.serve()` returns an `App` object with:
- `url`: The app's URL
- `endpoint`: The app's endpoint URL
- `status`: Current status of the app
- `name`: App name
```python
app = flyte.serve(app_env)
print(f"URL: {app.url}")
print(f"Endpoint: {app.endpoint}")
print(f"Status: {app.status}")
```
## Best practices
1. **Use for development**: App serving is ideal for development and testing.
2. **Override inputs**: Take advantage of input overrides for testing different configurations.
3. **Quick iteration**: Use `serve` for rapid development cycles.
4. **Switch to deploy**: Use **Serve and deploy apps > How app deployment works** for production deployments.
## Troubleshooting
**App not activating:**
- Check cluster connectivity
- Verify app configuration is correct
- Review container logs for errors
**Input overrides not working:**
- Verify input names match exactly
- Check that inputs are defined in the app environment
- Ensure you're using the `input_values` parameter correctly
**Slow serving:**
- Images may need to be built (first time is slower).
- Large code bundles can slow down deployment.
- Check network connectivity to the cluster.
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/how-app-deployment-works ===
# How app deployment works
Deployment is the recommended way to deploy apps to production. It creates versioned, immutable app deployments.
## Overview
When you deploy an app, the following happens:
1. **Code bundling**: Your app code is bundled and prepared
2. **Image building**: Container images are built (if needed)
3. **Deployment**: The app is deployed to your Flyte cluster
4. **Activation**: The app is automatically activated and ready to use
## Using the Python SDK
Deploy an app:
```python
import flyte
import flyte.app
app_env = flyte.app.AppEnvironment(
name="my-prod-app",
# ...
)
if __name__ == "__main__":
flyte.init_from_config()
deployments = flyte.deploy(app_env)
# Access deployed apps from deployments
for deployment in deployments:
for deployed_env in deployment.envs.values():
print(f"Deployed: {deployed_env.env.name}")
print(f"URL: {deployed_env.deployed_app.url}")
```
`flyte.deploy()` returns a list of `Deployment` objects. Each `Deployment` contains a dictionary of `DeployedEnvironment` objects (one for each environment deployed, including environment dependencies). For apps, the `DeployedEnvironment` is a `DeployedAppEnvironment` which has a `deployed_app` property of type `App`.
## Deployment plan
Flyte automatically creates a deployment plan that includes:
- The app you're deploying
- All **Configure apps > Apps depending on other environments** (via `depends_on`)
- Proper deployment order
```python
app1_env = flyte.app.AppEnvironment(name="backend", ...)
app2_env = flyte.app.AppEnvironment(name="frontend", depends_on=[app1_env], ...)
# Deploying app2_env will also deploy app1_env
deployments = flyte.deploy(app2_env)
# deployments contains both app1_env and app2_env
assert len(deployments) == 2
```
## Overriding App configuration at deployment time
If you need to override the app configuration at deployment time, you can use the `clone_with` method to create a new
app environment with the desired overrides.
```python
app_env = flyte.app.AppEnvironment(name="my-app", ...)
if __name__ == "__main__":
flyte.init_from_config()
deployments = flyte.deploy(
app_env.clone_with(app_env.name, resources=flyte.Resources(cpu="2", memory="2Gi"))
)
for deployment in deployments:
for deployed_env in deployment.envs.values():
print(f"Deployed: {deployed_env.env.name}")
print(f"URL: {deployed_env.deployed_app.url}")
```
## Activation/deactivation
Unlike serving, deployment does not automatically activate apps. You need to activate them explicitly:
```python
deployments = flyte.deploy(app_env)
from flyte.remote import App
app = App.get(name=app_env.name)
# deactivate the app
app.deactivate()
# activate the app
app.activate()
```
See **Serve and deploy apps > Activating and deactivating apps** for more details.
## Using the CLI
Deploy from the command line:
```bash
flyte deploy path/to/app.py app
```
Where `app` is the variable name of the `AppEnvironment` object.
You can also specify the following options:
```bash
flyte deploy path/to/app.py app \
--version v1.0.0 \
--project my-project \
--domain production \
--dry-run
```
## Example: Full deployment configuration
```python
import flyte
import flyte.app
app_env = flyte.app.AppEnvironment(
name="my-prod-app",
# ... configuration ...
)
if __name__ == "__main__":
flyte.init_from_config()
deployments = flyte.deploy(
app_env,
dryrun=False,
version="v1.0.0",
interactive_mode=False,
copy_style="loaded_modules",
)
# Access deployed apps from deployments
for deployment in deployments:
for deployed_env in deployment.envs.values():
app = deployed_env.deployed_app
print(f"Deployed: {deployed_env.env.name}")
print(f"URL: {app.url}")
# Activate the app
app.activate()
print(f"Activated: {app.name}")
```
## Best practices
1. **Use for production**: Deploy is designed for production use.
2. **Version everything**: Always specify versions for reproducibility.
3. **Test first**: Test with serve before deploying to production.
4. **Manage dependencies**: Use `depends_on` to manage app dependencies.
5. **Activation strategy**: Have a strategy for activating/deactivating apps.
6. **Rollback plan**: Keep old versions available for rollback.
7. **Use dry-run**: Test deployments with `dry_run=True` first.
8. **Separate environments**: Use different projects/domains for different environments.
9. **Input management**: Consider using environment-specific input values.
## Deployment status and return value
`flyte.deploy()` returns a list of `Deployment` objects. Each `Deployment` contains a dictionary of `DeployedEnvironment` objects:
```python
deployments = flyte.deploy(app_env)
for deployment in deployments:
for deployed_env in deployment.envs.values():
if hasattr(deployed_env, 'deployed_app'):
# Access deployed environment
env = deployed_env.env
app = deployed_env.deployed_app
# Access deployment info
print(f"Name: {env.name}")
print(f"URL: {app.url}")
print(f"Status: {app.deployment_status}")
```
For apps, each `DeployedAppEnvironment` includes:
- `env`: The `AppEnvironment` that was deployed
- `deployed_app`: The `App` object with properties like `url`, `endpoint`, `name`, and `deployment_status`
## Troubleshooting
**Deployment fails:**
- Check that all dependencies are available
- Verify image builds succeed
- Review deployment logs
**App not accessible:**
- Ensure the app is activated
- Check cluster connectivity
- Verify app configuration
**Version conflicts:**
- Use unique versions for each deployment
- Check existing app versions
- Clean up old versions if needed
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/activating-and-deactivating-apps ===
# Activating and deactivating apps
Apps deployed with `flyte.deploy()` need to be explicitly activated before they can serve traffic. Apps served with `flyte.serve()` are automatically activated.
## Activation
### Activate after deployment
After deploying an app, activate it:
```python
import flyte
from flyte.remote import App
# Deploy the app
deployments = flyte.deploy(app_env)
# Activate the app
app = App.get(name=app_env.name)
app.activate()
print(f"Activated app: {app.name}")
print(f"URL: {app.url}")
```
### Activate an app
When you get an app by name, you get the current app instance:
```python
app = App.get(name="my-app")
app.activate()
```
### Check activation status
Check if an app is active:
```python
app = App.get(name="my-app")
print(f"Active: {app.is_active()}")
print(f"Revision: {app.revision}")
```
## Deactivation
Deactivate an app when you no longer need it:
```python
app = App.get(name="my-app")
app.deactivate()
print(f"Deactivated app: {app.name}")
```
## Lifecycle management
### Typical deployment workflow
```python
# 1. Deploy new version
deployments = flyte.deploy(
app_env,
version="v2.0.0",
)
# 2. Get the deployed app
new_app = App.get(name="my-app")
# Test endpoints, etc.
# 3. Activate the new version
new_app.activate()
print(f"Deployed and activated version {new_app.revision}")
```
### Blue-green deployment
For zero-downtime deployments:
```python
# Deploy new version without deactivating old
new_deployments = flyte.deploy(
app_env,
version="v2.0.0",
)
new_app = App.get(name="my-app")
# Test new version
# ... testing ...
# Switch traffic to new version
new_app.activate()
print(f"Activated revision {new_app.revision}")
```
### Rollback
Roll back to a previous version:
```python
# Deactivate current version
current_app = App.get(name="my-app")
current_app.deactivate()
print(f"Deactivated revision {current_app.revision}")
```
## Using CLI
### Activate
```bash
flyte update app --activate my-app
```
### Deactivate
```bash
flyte update app --deactivate my-app
```
### Check status
```bash
flyte app status my-app
```
## Best practices
1. **Activate after testing**: Test deployed apps before activating
2. **Version management**: Keep track of which version is active
3. **Rollback plan**: Always have a rollback strategy
4. **Blue-green deployments**: Use blue-green for zero-downtime
5. **Monitor**: Monitor apps after activation
6. **Cleanup**: Deactivate and remove old versions periodically
## Automatic activation with serve
Apps served with `flyte.serve()` are automatically activated:
```python
# Automatically activated
app = flyte.serve(app_env)
print(f"Active: {app.is_active()}") # True
```
This is convenient for development but less suitable for production where you want explicit control over activation.
## Example: Complete deployment and activation
```python
import flyte
import flyte.app
from flyte.remote import App
app_env = flyte.app.AppEnvironment(
name="my-prod-app",
# ... configuration ...
)
if __name__ == "__main__":
flyte.init_from_config()
# Deploy
deployments = flyte.deploy(
app_env,
version="v1.0.0",
project="my-project",
domain="production",
)
# Get the deployed app
app = App.get(name="my-prod-app")
# Activate
app.activate()
print(f"Deployed and activated: {app.name}")
print(f"Revision: {app.revision}")
print(f"URL: {app.url}")
print(f"Active: {app.is_active()}")
```
## Troubleshooting
**App not accessible after activation:**
- Verify activation succeeded
- Check app logs for startup errors
- Verify cluster connectivity
- Check that the app is listening on the correct port
**Activation fails:**
- Check that the app was deployed successfully
- Verify app configuration is correct
- Check cluster resources
- Review deployment logs
**Cannot deactivate:**
- Ensure you have proper permissions
- Check if there are dependencies preventing deactivation
- Verify the app name and version
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/serve-and-deploy-apps/prefetching-models ===
# Prefetching models
Prefetching allows you to download and prepare HuggingFace models (including sharding for multi-GPU inference) before
deploying **Build apps > vLLM app** or **Build apps > SGLang app** apps. This speeds up deployment and ensures models are ready when your app starts.
## Why prefetch?
Prefetching models provides several benefits:
- **Faster deployment**: Models are pre-downloaded, so apps start faster
- **Reproducibility**: Models are versioned and stored in Flyte's object store
- **Sharding support**: Pre-shard models for multi-GPU tensor parallelism
- **Cost efficiency**: Download once, use many times
- **Offline support**: Models are cached in your storage backend
## Basic prefetch
### Using Python SDK
```python
import flyte
# Prefetch a HuggingFace model
run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")
# Wait for prefetch to complete
run.wait()
# Get the model path
model_path = run.outputs()[0].path
print(f"Model prefetched to: {model_path}")
```
### Using CLI
```bash
flyte prefetch hf-model Qwen/Qwen3-0.6B
```
Wait for completion:
```bash
flyte prefetch hf-model Qwen/Qwen3-0.6B --wait
```
## Using prefetched models
Use the prefetched model in your vLLM or SGLang app:
```python
from flyteplugins.vllm import VLLMAppEnvironment
import flyte
# Prefetch the model
run = flyte.prefetch.hf_model(repo="Qwen/Qwen3-0.6B")
run.wait()
# Use the prefetched model
vllm_app = VLLMAppEnvironment(
name="my-llm-app",
model_path=flyte.app.RunOutput(
type="directory",
run_name=run.name,
),
model_id="qwen3-0.6b",
resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
stream_model=True,
)
app = flyte.serve(vllm_app)
```
> [!TIP]
> You can also use prefetched models as inputs to your generic `AppEnvironment`s or `FastAPIAppEnvironment`s.
## Prefetch options
### Custom artifact name
```python
run = flyte.prefetch.hf_model(
repo="Qwen/Qwen3-0.6B",
artifact_name="qwen-0.6b-model", # Custom name for the stored model
)
```
### With HuggingFace token
If the model requires authentication:
```python
run = flyte.prefetch.hf_model(
repo="meta-llama/Llama-2-7b-hf",
hf_token_key="HF_TOKEN", # Name of Flyte secret containing HF token
)
```
The default value for `hf_token_key` is `HF_TOKEN`, the name of the Flyte secret containing your HuggingFace token. If this secret doesn't exist yet, you can create one as described in **Configure tasks > Secrets**.
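For example, following the same CLI pattern used for other secrets:
```bash
flyte create secret HF_TOKEN <your-huggingface-token>
```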
### With resources
By default, the prefetch task uses minimal resources (2 CPUs, 8GB of memory, 50Gi of disk storage), using
filestreaming logic to move the model weights from HuggingFace to your storage backend directly.
In some cases, the HuggingFace model may not support filestreaming, in which case the prefetch task falls back to downloading the model weights to the task pod's disk storage first and then uploading them to your storage backend. In that case, you can specify custom resources for the prefetch task to override the defaults.
```python
run = flyte.prefetch.hf_model(
repo="Qwen/Qwen3-0.6B",
cpu="4",
mem="16Gi",
ephemeral_storage="100Gi",
)
```
## Sharding models for multi-GPU
### vLLM sharding
Shard a model for tensor parallelism:
```python
from flyte.prefetch import ShardConfig, VLLMShardArgs
run = flyte.prefetch.hf_model(
repo="meta-llama/Llama-2-70b-hf",
resources=flyte.Resources(cpu="8", memory="32Gi", gpu="L40s:4"),
shard_config=ShardConfig(
engine="vllm",
args=VLLMShardArgs(
tensor_parallel_size=4,
dtype="auto",
trust_remote_code=True,
),
),
hf_token_key="HF_TOKEN",
)
run.wait()
```
Currently, the `flyte.prefetch.hf_model` function only supports sharding models
using the `vllm` engine. Once sharded, these models can be loaded with other
frameworks such as `transformers`, `torch`, or `sglang`.
### Using shard config via CLI
You can also use a YAML file for sharding configuration to use with the
`flyte prefetch hf-model` CLI command:
```yaml
# shard_config.yaml
engine: vllm
args:
tensor_parallel_size: 8
dtype: auto
trust_remote_code: true
```
Then run the CLI command:
```bash
flyte prefetch hf-model meta-llama/Llama-2-70b-hf \
--shard-config shard_config.yaml \
--accelerator L40s:8 \
--hf-token-key HF_TOKEN
```
## Using prefetched sharded models
After prefetching and sharding, serve the model in your app:
```python
# Use in vLLM app
import flyte
from flyte.prefetch import ShardConfig, VLLMShardArgs
from flyteplugins.vllm import VLLMAppEnvironment
vllm_app = VLLMAppEnvironment(
    name="multi-gpu-llm-app",
    # placeholder; overridden below with the prefetched, sharded model
    model_hf_path="meta-llama/Llama-2-70b-hf",
    model_id="llama-2-70b",
resources=flyte.Resources(
cpu="8",
memory="32Gi",
gpu="L40s:4", # Match the number of GPUs used for sharding
),
extra_args=[
"--tensor-parallel-size", "4", # Match sharding config
],
)
if __name__ == "__main__":
    flyte.init_from_config()
    # Prefetch with sharding
run = flyte.prefetch.hf_model(
repo="meta-llama/Llama-2-70b-hf",
accelerator="L40s:4",
shard_config=ShardConfig(
engine="vllm",
args=VLLMShardArgs(tensor_parallel_size=4),
),
)
run.wait()
flyte.serve(
vllm_app.clone_with(
name=vllm_app.name,
# override the model path to use the prefetched model
model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
            # set model_hf_path to None
            model_hf_path=None,
# stream the model from flyte object store directly to the GPU
stream_model=True,
)
)
```
## CLI options
Complete CLI usage:
```bash
flyte prefetch hf-model <repo> \
--artifact-name <artifact-name> \
--architecture <architecture> \
--task <task> \
--modality text \
--format safetensors \
--model-type transformer \
--short-description "Description" \
--force 0 \
--wait \
--hf-token-key HF_TOKEN \
--cpu 4 \
--mem 16Gi \
--ephemeral-storage 100Gi \
--accelerator L40s:4 \
--shard-config shard_config.yaml
```
## Complete example
Here's a complete example of prefetching and using a model:
```python
import flyte
from flyteplugins.vllm import VLLMAppEnvironment
from flyte.prefetch import ShardConfig, VLLMShardArgs
# define the app environment
vllm_app = VLLMAppEnvironment(
name="qwen-serving-app",
    # placeholder; overridden at serve time with the prefetched model path
    model_hf_path="Qwen/Qwen3-0.6B",
model_id="qwen3-0.6b",
resources=flyte.Resources(
cpu="4",
memory="16Gi",
gpu="L40s:1",
disk="10Gi",
),
scaling=flyte.app.Scaling(
replicas=(0, 1),
scaledown_after=600,
),
requires_auth=False,
)
if __name__ == "__main__":
    flyte.init_from_config()
    # prefetch the model
    print("Prefetching model...")
run = flyte.prefetch.hf_model(
repo="Qwen/Qwen3-0.6B",
artifact_name="qwen-0.6b",
cpu="4",
mem="16Gi",
ephemeral_storage="50Gi",
)
# wait for completion
print("Waiting for prefetch to complete...")
run.wait()
print(f"Model prefetched: {run.outputs()[0].path}")
# deploy the app
print("Deploying app...")
app = flyte.serve(
vllm_app.clone_with(
name=vllm_app.name,
model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
            model_hf_path=None,
stream_model=True,
)
)
print(f"App deployed: {app.url}")
```
## Best practices
1. **Prefetch before deployment**: Prefetch models before deploying apps for faster startup
2. **Version models**: Use meaningful artifact names to easily identify the model in object store paths
3. **Shard appropriately**: Shard models for the GPU configuration you'll use for inference
4. **Cache prefetched models**: Once prefetched, models are cached in your storage backend for faster serving
## Troubleshooting
**Prefetch fails:**
- Check HuggingFace token (if required)
- Verify model repo exists and is accessible
- Check resource availability
- Review prefetch task logs
**Sharding fails:**
- Ensure accelerator matches shard config
- Check GPU memory is sufficient
- Verify `tensor_parallel_size` matches GPU count
- Review prefetch task logs for sharding-related errors
**Model not found in app:**
- Verify RunOutput references correct run name
- Check that prefetch completed successfully
- Ensure model_path is set correctly
- Review app startup logs
=== PAGE: https://www.union.ai/docs/v2/flyte/user-guide/considerations ===
# Considerations
Flyte 2 represents a substantial change from Flyte 1.
While the static graph execution model will soon be available and will mirror Flyte 1 almost exactly, the primary mode of execution in Flyte 2 should remain pure-Python-based.
That is, each Python-based task action can act as its own engine: kicking off sub-actions, assembling their outputs, and passing them on to yet other sub-actions.
While this model of execution comes with an enormous amount of flexibility, that flexibility warrants some caveats to keep in mind when authoring your tasks.
## Non-deterministic behavior
When a task launches another task, a new Action ID is determined.
This ID is a hash of the task's inputs and the task definition itself, along with some other information.
The fact that this ID is consistently hashed is important when it comes to things like recovery and replay.
For example, assume you have the following tasks:
```python
@env.task
async def t1():
    val = get_int_input()
    await t2(val=val)

@env.task
async def t2(val: int): ...
```
If you run `t1` and it launches the downstream `t2` task, and the pod executing `t1` then fails, Flyte will, when it restarts `t1`, automatically detect that `t2` is still running and reattach to it.
If `t2` finishes in the interim, its results are simply used.
However, if you introduce non-determinism into the picture, then that guarantee is no longer there.
To give a contrived example:
```python
@env.task
async def t1():
    val = get_int_input()
    now = datetime.now()
    if now.second % 2 == 0:
        await t2(val=val)
    else:
        await t3(val=val)
```
Here, depending on what time it is, either `t2` or `t3` may end up running.
Unlike in the earlier scenario, if `t1` crashes unexpectedly and Flyte retries the execution, a different downstream task may be kicked off on the retry.
### Dealing with non-determinism
As a developer, the best way to manage non-deterministic behavior (if it is unavoidable) is to be able to observe it and see exactly what is happening in your code. Flyte 2 provides precisely the tool needed to enable this: Traces.
With this feature you decorate the sub-task functions in your code with `@trace`, enabling checkpointing, reproducibility and recovery at a fine-grained level. See **Build tasks > Traces** for more details.
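For example, here is a minimal sketch of the contrived branch above, with the non-deterministic clock read pulled into a traced helper (`pick_branch` is a hypothetical name used purely for illustration):
```python
from datetime import datetime

import flyte

@flyte.trace
async def pick_branch() -> bool:
    # The result of a traced call is checkpointed, so if t1 is retried
    # the recorded branch decision is replayed instead of being
    # recomputed from the current clock.
    return datetime.now().second % 2 == 0

@env.task
async def t1():
    val = get_int_input()
    if await pick_branch():
        await t2(val=val)
    else:
        await t3(val=val)
```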
## Type safety
In Flyte 1, the top-level workflow was defined by a Python-like DSL that was compiled into a static DAG composed of tasks, each of which was, internally, defined in real Python.
The system was able to guarantee type safety across task boundaries because the task definitions were static and the inputs and outputs were defined in a way that Flytekit could validate them.
In Flyte 2, the top-level workflow is defined by Python code that runs at runtime (unless using a compiled task).
This means that the system can no longer guarantee type safety at the workflow level.
Happily, the Python ecosystem has evolved considerably since Flyte 1, and Python type hints are now a standard way to define types.
Consequently, in Flyte 2, developers should use Python type hints and type checkers like `mypy` to ensure type safety at all levels, including the top-most task (i.e., the "workflow" level).
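For example, fully annotated task signatures let `mypy` check the whole call chain, including the top-level task. A minimal sketch:
```python
@env.task
async def average(values: list[float]) -> float:
    return sum(values) / len(values)

@env.task
async def main() -> float:
    # mypy flags this call site if the argument type ever
    # drifts away from list[float]
    return await average(values=[1.0, 2.0, 3.0])
```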
## No global state
A core principle of Flyte 2 (shared with Flyte 1) is that you should not try to maintain global state across your workflow; it will not be propagated across task containers.
In a single-process Python program, global variables are available across functions.
In the distributed execution model of Flyte, each task runs in its own container, and each container is isolated from the others.
If there is some state that needs to be preserved, it must be reconstructable through repeated deterministic execution.
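A minimal sketch of the pitfall:
```python
COUNTER = 0  # module-level global

@env.task
async def bump():
    global COUNTER
    COUNTER += 1  # mutates the copy in *this* task's container only

@env.task
async def main():
    await bump()
    # Still prints 0: main runs in a different container, so the
    # mutation made inside bump() is never visible here.
    print(COUNTER)
```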
## Driver pod requirements
Tasks don't have to kick off downstream tasks, of course; a task may itself be a leaf-level, atomic unit of compute.
However, when a task does run other tasks, and especially when it assembles the outputs of those tasks, the parent task becomes a driver pod of sorts.
In Flyte 1, this assembling of intermediate outputs was done by Flyte Propeller.
In Flyte 2, it's done by the parent task.
This means that the pod running your parent task must be appropriately sized and should ideally not be CPU-bound, otherwise it will slow down the evaluation and kickoff of downstream tasks.
For example, consider this scenario:
```python
@env.task
async def t_main():
await t1()
local_cpu_intensive_function()
await t2()
```
The pod running `t_main` will stall between `t1` and `t2` while the local CPU-intensive function runs. Your parent tasks should ideally focus only on orchestration.
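A minimal sketch of the remedy is to move the heavy computation into its own task, so the parent stays free to orchestrate:
```python
@env.task
async def cpu_intensive_step():
    local_cpu_intensive_function()

@env.task
async def t_main():
    await t1()
    # The heavy work now runs in its own right-sized pod instead of
    # blocking the driver between t1 and t2.
    await cpu_intensive_step()
    await t2()
```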
## OOM risk from materialized I/O
A more nuanced point to keep in mind is that if you're not using the soon-to-be-released ref mode, outputs are actually
materialized. That is, if you have the following scenario:
```python
from typing import List

@env.task
async def produce_1gb_list() -> List[float]: ...

@env.task
async def t1():
    list_floats = await produce_1gb_list()
    await t2(floats=list_floats)
```
The pod running `t1` needs enough memory to hold that 1 GB of floats, since they are fully materialized in the pod's memory.
This can lead to out-of-memory issues.
Note that `flyte.io.File`, `flyte.io.Dir` and `flyte.io.DataFrame` do not suffer from this: while they are materialized, it is only as pointers to offloaded data, so their memory footprint is much lower.
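For example, a minimal sketch of the offloaded alternative (`produce_1gb_file` is a hypothetical counterpart to the task above):
```python
import flyte.io

@env.task
async def produce_1gb_file() -> flyte.io.File: ...

@env.task
async def t1():
    f = await produce_1gb_file()
    # Only a lightweight pointer to the offloaded data crosses the
    # task boundary; the 1 GB payload is never held in t1's memory.
    await t2(floats_file=f)
```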
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials ===
# Tutorials
This section contains tutorials that showcase relevant use cases and provide step-by-step instructions on how to implement various features using Flyte and Union.
### π **Multi-agent trading simulation**
A multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.
### π **Run LLM-generated code**
Securely execute and iterate on LLM-generated code using a code agent with error reflection and retry logic.
### π **Deep research**
Build an agentic workflow for deep research with multi-step reasoning and evaluation.
### π **Hyperparameter optimization**
Run large-scale HPO experiments with zero manual tracking, deterministic results, and automatic recovery.
### π **Automatic prompt engineering**
Easily run prompt optimization with real-time observability, traceability, and automatic recovery.
### π **Text-to-SQL**
Learn how to turn natural language questions into SQL queries with Flyte and LlamaIndex, and explore prompt optimization in practice.
## Subpages
- **Automatic prompt engineering**
- **Deep research**
- **Hyperparameter optimization**
- **Multi-agent trading simulation**
- **Run LLM-generated code**
- **Text-to-SQL**
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/auto_prompt_engineering ===
# Automatic prompt engineering
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/auto_prompt_engineering).
When building with LLMs and agents, the first prompt almost never works. We usually need several iterations before results are useful. Doing this manually is slow, inconsistent, and hard to reproduce.
Flyte turns prompt engineering into a systematic process. With Flyte we can:
- Generate candidate prompts automatically.
- Run evaluations in parallel.
- Track results in real time with built-in observability.
- Recover from failures without losing progress.
- Trace the lineage of every experiment for reproducibility.
And we're not limited to prompts. Just like **hyperparameter optimization** in ML, we can tune model temperature, retrieval strategies, tool usage, and more. Over time, this grows into full agentic evaluations, tracking not only prompts but also how agents behave, make decisions, and interact with their environment.
In this tutorial, we'll build an automated prompt engineering pipeline with Flyte, step by step.
## Set up the environment
First, let's configure our task environment.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

# CSS styles for the live HTML report (elided here; see the linked source)
CSS = """
...
"""
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
We need an API key to call GPT-4.1 (our optimization model). Add it as a Flyte secret:
```
flyte create secret openai_api_key
```
We also define CSS styles for live HTML reports that track prompt optimization in real time (the styles are elided above; see the linked source).
## Prepare the evaluation dataset
Next, we define our golden dataset: a set of inputs with known outputs. This dataset is used to evaluate the quality of generated prompts.
For this tutorial, we use a small geometric shapes dataset. To keep it portable, the data prep task takes a CSV file (as a Flyte `File` or a string for files available remotely) and splits it into train and test subsets.
If you already have prompts and outputs in Google Sheets, simply export them as CSV with two columns: `input` and `target`.
```
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into
    train/test DataFrames. The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )
    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
This approach works with any dataset. You can swap in your own with no extra dependencies.
## Define models
We use two models:
- **Target model** β the one we want to optimize.
- **Review model** β the one that evaluates candidate prompts.
First, we capture all model parameters in a dataclass:
```
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
Then we define a Flyte `trace` to call the model. Unlike a task, a trace runs within the same runtime as the parent process. Since the model is hosted externally, this keeps the call lightweight but still observable.
```
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
Finally, we tie the two trace calls together in a helper that queries the target model and then asks the review model to judge the response:
```
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
## Evaluate prompts
We now define the evaluation process.
Each row in the dataset is evaluated in parallel, with a semaphore to control concurrency. A helper function ties the `generate_and_review` call together with an HTML report template. Using `asyncio.gather`, we evaluate many rows at once.
The function measures accuracy as the fraction of responses that match the ground truth. Flyte streams these results to the UI, so you can watch evaluations happen live.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pandas==2.3.1",
# "pyarrow==21.0.0",
# "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File
env = flyte.TaskEnvironment(
name="auto-prompt-engineering",
image=flyte.Image.from_uv_script(
__file__, name="auto-prompt-engineering", pre=True
),
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
resources=flyte.Resources(cpu=1),
)
CSS = """
"""
# {{/docs-fragment env}}
# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Train/Test split
df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})
return df_train, df_test
# {{/docs-fragment data_prep}}
# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
# {{/docs-fragment model_config}}
# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
# {{/docs-fragment call_model}}
# {{docs-fragment generate_and_review}}
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
) -> dict:
# Generate response from target model
response = await call_model(
target_model_config,
[
{"role": "system", "content": target_model_config.prompt},
{"role": "user", "content": question},
],
)
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"is_correct": verdict_clean == "true",
}
# {{/docs-fragment generate_and_review}}
async def run_grouped_task(
i,
index,
question,
answer,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "β Yes"
else:
correct_html = "β No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
# Update chart
await flyte.report.log.aio(
f"",
do_flush=True,
)
# Add row to table
await flyte.report.log.aio(
f"""
""",
do_flush=True,
)
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
await flyte.report.log.aio(
f"""
{html.escape(prompt)}
{pct:.1f}%
""",
do_flush=True,
)
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
csv_file: File | str = "https://dub.sh/geometric-shapes",
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="Solve the given problem about geometric shapes. Think step by step.",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""You are a review model tasked with evaluating the correctness of a response to a navigation problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth number.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.
Model Response:
{response}
Ground Truth:
{answer}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used together with a problem statement around geometric shapes.
This SVG path element draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
(B)
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 3,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(csv_file, str) and os.path.isfile(csv_file):
csv_file = await File.from_local(csv_file)
df_train, df_test = await data_prep(csv_file)
best_prompt, training_accuracy = await prompt_optimizer(
df_train,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
)
return {
"best_prompt": best_prompt,
"training_accuracy": training_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
# {{docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
## Optimize prompts
Optimization builds on evaluation. We give the optimizer model:
- the history of prompts tested so far, and
- their accuracies.
The model then proposes a new prompt.
We start with a _baseline_ evaluation using the user-provided prompt. Then, for each iteration, the optimizer suggests a new prompt, which we evaluate and log. We continue until we hit the iteration limit.
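The full `prompt_optimizer` implementation lives in the source file linked below. As orientation, here is a minimal sketch of its core loop, assuming the `evaluate_prompt` helper, `call_model`, `ModelConfig`, and the imports from this tutorial's source file (an approximation of the idea, not the exact implementation; `prompt_optimizer_sketch` is a hypothetical name):
```
import re  # already imported at the top of optimizer.py


# Hedged sketch of the optimizer loop; the real `prompt_optimizer` is in the
# linked source and also streams progress to the Flyte report.
async def prompt_optimizer_sketch(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    # Baseline: score the user-provided prompt first.
    baseline = await evaluate_prompt(
        df_train, target_model_config, review_model_config, concurrency
    )
    history = [(target_model_config.prompt, baseline)]
    for _ in range(max_iterations):
        # Show the optimizer model every prompt tried so far, worst to best.
        ranked = sorted(history, key=lambda pair: pair[1])
        prompt_scores_str = "\n\n".join(
            f"Prompt: {p}\nAccuracy: {a:.2f}" for p, a in ranked
        )
        suggestion = await call_model(
            optimizer_model_config,
            [
                {
                    "role": "system",
                    "content": optimizer_model_config.prompt.format(
                        prompt_scores_str=prompt_scores_str
                    ),
                }
            ],
        )
        # The optimizer prompt asks for the new prompt in double square brackets.
        match = re.search(r"\[\[(.*?)\]\]", suggestion, re.DOTALL)
        if not match:
            continue
        candidate = match.group(1).strip()
        target_model_config.prompt = candidate
        accuracy = await evaluate_prompt(
            df_train, target_model_config, review_model_config, concurrency
        )
        history.append((candidate, accuracy))
    # Best prompt seen across the baseline and all iterations.
    return max(history, key=lambda pair: pair[1])
```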
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
At the end, we return the best prompt and its accuracy. The report shows how accuracy improves over time and which prompts were tested.
## Build the full pipeline
The entrypoint task wires everything together:
- Accepts model configs, dataset, iteration count, and concurrency.
- Runs data preparation.
- Calls the optimizer.
- Evaluates both baseline and best prompts on the test set.
```
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a navigation problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth number.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.
Model Response:
{response}
Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used together with a problem statement around geometric shapes.
This SVG path element draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
(B)
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)
    df_train, df_test = await data_prep(csv_file)
    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )
    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )
        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )
    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }
# {{/docs-fragment auto_prompt_engineering}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
## Run it
We add a simple main block so we can run the workflow as a script:
```
# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
Run it with:
```
uv run --prerelease=allow optimizer.py
```
## Why this matters
Most prompt engineering pipelines start as quick scripts or notebooks. They're fine for experimenting, but they're difficult to scale, reproduce, or debug when things go wrong.
With Flyte 2, we get a more reliable setup:
- Run many evaluations in parallel with [async Python](../../user-guide/flyte-2/async#true-parallelism-for-all-workloads) or the [`flyte.map` function](../../user-guide/flyte-2/async#the-flytemap-function-familiar-patterns), as sketched after this list.
- Watch accuracy improve in real time and link results back to the exact dataset, prompt, and model config used.
- Resume cleanly after failures without rerunning everything from scratch.
- Reuse the same pattern to tune other parameters like temperature, retrieval depth, or agent strategies, not just prompts.
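The parallelism in this tutorial boils down to one reusable pattern: fan out coroutines with `asyncio.gather` while a semaphore caps concurrency, just as `run_grouped_task` does above. A self-contained sketch of that pattern (`bounded_gather`, `demo`, and `work` are illustrative names, not part of the Flyte API):
```
import asyncio


async def bounded_gather(coros, limit: int = 10):
    """Run coroutines concurrently, at most `limit` at a time."""
    sem = asyncio.Semaphore(limit)

    async def guarded(coro):
        async with sem:
            return await coro

    return await asyncio.gather(*(guarded(c) for c in coros))


async def demo():
    async def work(i: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for an LLM call
        return i * i

    # Twenty coroutines, never more than five in flight at once.
    print(await bounded_gather([work(i) for i in range(20)], limit=5))


if __name__ == "__main__":
    asyncio.run(demo())
```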
## Next steps
You now have a working automated prompt engineering pipeline. Here's how you can take it further:
- **Optimize beyond prompts**: Tune temperature, retrieval strategies, or tool usage just like prompts (see the sketch after this list).
- **Expand evaluation metrics**: Add latency, cost, robustness, or diversity alongside accuracy.
- **Move toward agentic evaluation**: Instead of single prompts, test how agents plan, use tools, and recover from failures in long-horizon tasks.
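As one concrete illustration of the first point, the same evaluate-and-compare loop can sweep a sampling parameter instead of the prompt. A hedged sketch, reusing the `evaluate_prompt` helper and `ModelConfig` from this tutorial (`temperature_sweep` is a hypothetical helper and the grid values are arbitrary):
```
# Hedged sketch: reuse evaluate_prompt to sweep temperature instead of prompts.
async def temperature_sweep(
    df: pd.DataFrame,
    target_cfg: ModelConfig,
    review_cfg: ModelConfig,
    concurrency: int = 10,
) -> tuple[float, float]:
    """Return (best_temperature, accuracy) over a small, arbitrary grid."""
    best_temp, best_acc = 0.0, -1.0
    for temp in (0.0, 0.3, 0.7, 1.0):  # illustrative grid, not tuned
        target_cfg.temperature = temp
        acc = await evaluate_prompt(df, target_cfg, review_cfg, concurrency)
        if acc > best_acc:
            best_temp, best_acc = temp, acc
    return best_temp, best_acc
```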
With this foundation, prompt engineering becomes repeatable, observable, and scalable, ready for production-grade LLM and agent systems.
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/deep-research ===
# Deep research
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/deep_research_agent); based on work by [Together AI](https://github.com/togethercomputer/open_deep_research).
This example demonstrates how to build an agentic workflow for deep research: a multi-step reasoning system that mirrors how a human researcher explores, analyzes, and synthesizes information from the web.
Deep research refers to the iterative process of thoroughly investigating a topic: identifying relevant sources, evaluating their usefulness, refining the research direction, and ultimately producing a well-structured summary or report. It's a long-running task that requires the agent to reason over time, adapt its strategy, and chain multiple steps together, making it an ideal fit for an agentic architecture.
In this example, we use:
- [Tavily](https://www.tavily.com/) to search for and retrieve high-quality online resources.
- [LiteLLM](https://litellm.ai/) to route LLM calls that perform reasoning, evaluation, and synthesis.
The agent executes a multi-step trajectory (sketched in code right after this list):
- Parallel search across multiple queries.
- Evaluation of retrieved results.
- Adaptive iteration: If results are insufficient, it formulates new research queries and repeats the search-evaluate cycle.
- Synthesis: After a fixed number of iterations, it produces a comprehensive research report.
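In code, that trajectory is roughly the loop below. This is a simplified sketch for orientation only (`research_trajectory` is a hypothetical name); the real `research_topic` task, shown later on this page, adds retries, query limits, deduplication, source filtering, and grouped observability:
```
# Simplified shape of the agent loop, assuming the tasks defined later on
# this page (generate_research_queries, search_all_queries, etc.).
async def research_trajectory(
    topic: str,
    budget: int,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    summarization_model: str,
    answer_model: str,
) -> str:
    # 1. Plan: turn the topic into focused search queries.
    queries = await generate_research_queries(
        topic, planning_model, json_model, prompts_file
    )
    # 2. Parallel search across all queries.
    results = await search_all_queries(queries, summarization_model, prompts_file)
    # 3. Adaptive iteration: request follow-up queries until satisfied or out of budget.
    for _ in range(budget):
        follow_ups = await evaluate_research_completeness(
            topic, results, queries, prompts_file, planning_model, json_model
        )
        if not follow_ups:
            break
        results = results + await search_all_queries(
            follow_ups, summarization_model, prompts_file
        )
        queries += follow_ups
    # 4. Synthesis: produce the final research report.
    return await generate_research_answer(
        topic, results, True, prompts_file, answer_model
    )
```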
What makes this workflow compelling is its dynamic, evolving nature. The agent isn't just following a fixed plan; it's making decisions in context, using multiple prompts and reasoning steps to steer the process.
Flyte is uniquely well-suited for this kind of system. It provides:
- Structured composition of dynamic reasoning steps
- Built-in parallelism for faster search and evaluation
- Traceability and observability into each step and iteration
- Scalability for long-running or compute-intensive workloads
Throughout this guide, we'll show how to design this workflow using the Flyte SDK, and how to unlock the full potential of agentic development with tools you already know and trust.
## Setting up the environment
Let's begin by setting up the task environment. We define the following components:
- Secrets for Together and Tavily API keys
- A custom image with required Python packages and apt dependencies (`pandoc`, `texlive-xetex`)
- External YAML file with all LLM prompts baked into the container
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path
import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
DeepResearchResult,
DeepResearchResults,
ResearchPlan,
SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results
TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096
logging = AgentLogger("together.open_deep_research")
env = flyte.TaskEnvironment(
name="deep-researcher",
secrets=[
flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
],
image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
.with_apt_packages("pandoc", "texlive-xetex")
.with_source_file(Path("prompts.yaml"), "/root"),
resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
The Python packages are declared at the top of the file using the `uv` script style:
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b6",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# ///
```
## Generate research queries
This task converts a user prompt into a list of focused queries. It makes two LLM calls: one to generate a high-level research plan, and another to parse that plan into atomic search queries.
```
@env.task
async def generate_research_queries(
topic: str,
planning_model: str,
json_model: str,
prompts_file: File,
) -> list[str]:
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
PLANNING_PROMPT = prompts["planning_prompt"]
plan = ""
logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=PLANNING_PROMPT,
message=f"Research Topic: {topic}",
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
plan += chunk
print(chunk, end="", flush=True)
SEARCH_PROMPT = prompts["plan_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=SEARCH_PROMPT,
message=f"Plan to be parsed: {plan}",
response_format={
"type": "json_object",
"schema": ResearchPlan.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
plan = json.loads(response_json)
return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
raw_content: str,
query: str,
prompt: str,
summarization_model: str,
) -> str:
"""Summarize content asynchronously using the LLM"""
logging.info("Summarizing content asynchronously using the LLM")
result = ""
async for chunk in asingle_shot_llm_call(
model=summarization_model,
system_prompt=prompt,
message=f"{raw_content}\n\n{query}",
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
result += chunk
return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
query: str,
prompts_file: File,
summarization_model: str,
) -> DeepResearchResults:
"""Perform search for a single query"""
if len(query) > 400:
# NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
query = query[:400]
logging.info(f"Truncated query to 400 characters: {query}")
response = await atavily_search_results(query)
logging.info("Tavily Search Called.")
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
with flyte.group("summarize-content"):
# Create tasks for summarization
summarization_tasks = []
result_info = []
for result in response.results:
if result.raw_content is None:
continue
task = _summarize_content_async(
result.raw_content,
query,
RAW_CONTENT_SUMMARIZER_PROMPT,
summarization_model,
)
summarization_tasks.append(task)
result_info.append(result)
# Use return_exceptions=True to prevent exceptions from propagating
summarized_contents = await asyncio.gather(
*summarization_tasks, return_exceptions=True
)
# Filter out exceptions
summarized_contents = [
result for result in summarized_contents if not isinstance(result, Exception)
]
formatted_results = []
for result, summarized_content in zip(result_info, summarized_contents):
formatted_results.append(
DeepResearchResult(
title=result.title,
link=result.link,
content=result.content,
raw_content=result.raw_content,
filtered_raw_content=summarized_content,
)
)
return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
"""Execute searches for all queries in parallel"""
tasks = []
results_list = []
tasks = [
search_and_summarize(query, prompts_file, summarization_model)
for query in queries
]
if tasks:
res_list = await asyncio.gather(*tasks)
results_list.extend(res_list)
# Combine all results
combined_results = DeepResearchResults(results=[])
for results in results_list:
combined_results = combined_results + results
return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
topic: str,
results: DeepResearchResults,
queries: list[str],
prompts_file: File,
planning_model: str,
json_model: str,
) -> list[str]:
"""
Evaluate if the current search results are sufficient or if more research is needed.
Returns an empty list if research is complete, or a list of additional queries if more research is needed.
"""
# Format the search results for the LLM
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
EVALUATION_PROMPT = prompts["evaluation_prompt"]
logging.info("\nEvaluation: ")
evaluation = ""
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=EVALUATION_PROMPT,
message=(
f"{topic}\n\n"
f"{queries}\n\n"
f"{formatted_results}"
),
response_format=None,
max_completion_tokens=None,
):
evaluation += chunk
print(chunk, end="", flush=True)
EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=EVALUATION_PARSING_PROMPT,
message=f"Evaluation to be parsed: {evaluation}",
response_format={
"type": "json_object",
"schema": ResearchPlan.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
evaluation = json.loads(response_json)
return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
topic: str,
results: DeepResearchResults,
prompts_file: File,
planning_model: str,
json_model: str,
max_sources: int,
) -> DeepResearchResults:
"""Filter the search results based on the research plan"""
# Format the search results for the LLM, without the raw content
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
FILTER_PROMPT = prompts["filter_prompt"]
logging.info("\nFilter response: ")
filter_response = ""
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=FILTER_PROMPT,
message=(
f"{topic}\n\n"
f"{formatted_results}"
),
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
filter_response += chunk
print(chunk, end="", flush=True)
logging.info(f"Filter response: {filter_response}")
FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=FILTER_PARSING_PROMPT,
message=f"Filter response to be parsed: {filter_response}",
response_format={
"type": "json_object",
"schema": SourceList.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
sources = json.loads(response_json)["sources"]
logging.info(f"Filtered sources: {sources}")
if max_sources != -1:
sources = sources[:max_sources]
# Filter the results based on the source list
filtered_results = [
results.results[i - 1] for i in sources if i - 1 < len(results.results)
]
return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think>...</think> tags"""
    while "<think>" in answer and "</think>" in answer:
        start = answer.find("<think>")
        end = answer.find("</think>") + len("</think>")
        answer = answer[:start] + answer[end:]
    return answer
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
topic: str,
results: DeepResearchResults,
remove_thinking_tags: bool,
prompts_file: File,
answer_model: str,
) -> str:
"""
Generate a comprehensive answer to the research topic based on the search results.
Returns a detailed response that synthesizes information from all search results.
"""
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
ANSWER_PROMPT = prompts["answer_prompt"]
answer = ""
async for chunk in asingle_shot_llm_call(
model=answer_model,
system_prompt=ANSWER_PROMPT,
message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
response_format=None,
# NOTE: This is the max_token parameter for the LLM call on Together AI,
# may need to be changed for other providers
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
answer += chunk
# this is just to avoid typing complaints
if answer is None or not isinstance(answer, str):
logging.error("No answer generated")
return "No answer generated"
if remove_thinking_tags:
# Remove content within tags
answer = _remove_thinking_tags(answer)
# Remove markdown code block markers if they exist at the beginning
if answer.lstrip().startswith("```"):
# Find the first line break after the opening backticks
first_linebreak = answer.find("\n", answer.find("```"))
if first_linebreak != -1:
# Remove everything up to and including the first line break
answer = answer[first_linebreak + 1 :]
# Remove closing code block if it exists
if answer.rstrip().endswith("```"):
answer = answer.rstrip()[:-3].rstrip()
return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
topic: str,
budget: int = 3,
remove_thinking_tags: bool = True,
max_queries: int = 5,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 40,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
prompts_file: File | str = "prompts.yaml",
) -> str:
"""Main method to conduct research on a topic. Will be used for weave evals."""
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
# Step 1: Generate initial queries
queries = await generate_research_queries(
topic=topic,
planning_model=planning_model,
json_model=json_model,
prompts_file=prompts_file,
)
queries = [topic, *queries[: max_queries - 1]]
all_queries = queries.copy()
logging.info(f"Initial queries: {queries}")
if len(queries) == 0:
logging.error("No initial queries generated")
return "No initial queries generated"
# Step 2: Perform initial search
results = await search_all_queries(queries, summarization_model, prompts_file)
logging.info(f"Initial search complete, found {len(results.results)} results")
# Step 3: Conduct iterative research within budget
for iteration in range(budget):
with flyte.group(f"eval_iteration_{iteration}"):
# Evaluate if more research is needed
additional_queries = await evaluate_research_completeness(
topic=topic,
results=results,
queries=all_queries,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
)
# Filter out empty strings and check if any queries remain
additional_queries = [q for q in additional_queries if q]
if not additional_queries:
logging.info("No need for additional research")
break
# for debugging purposes we limit the number of queries
additional_queries = additional_queries[:max_queries]
logging.info(f"Additional queries: {additional_queries}")
# Expand research with new queries
new_results = await search_all_queries(
additional_queries, summarization_model, prompts_file
)
logging.info(
f"Follow-up search complete, found {len(new_results.results)} results"
)
results = results + new_results
all_queries.extend(additional_queries)
# Step 4: Generate final answer
logging.info(f"Generating final answer for topic: {topic}")
results = results.dedup()
logging.info(f"Deduplication complete, kept {len(results.results)} results")
filtered_results = await filter_results(
topic=topic,
results=results,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
)
logging.info(
f"LLM Filtering complete, kept {len(filtered_results.results)} results"
)
# Generate final answer
answer = await generate_research_answer(
topic=topic,
results=filtered_results,
remove_thinking_tags=remove_thinking_tags,
prompts_file=prompts_file,
answer_model=answer_model,
)
return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
topic: str = (
"List the essential requirements for a developer-focused agent orchestration system."
),
prompts_file: File | str = "/root/prompts.yaml",
budget: int = 2,
remove_thinking_tags: bool = True,
max_queries: int = 3,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 10,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
answer = await research_topic(
topic=topic,
budget=budget,
remove_thinking_tags=remove_thinking_tags,
max_queries=max_queries,
answer_model=answer_model,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
summarization_model=summarization_model,
prompts_file=prompts_file,
)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
toc_image_url = await generate_toc_image(
yaml.safe_load(yaml_contents)["data_visualization_prompt"],
planning_model,
topic,
)
html_content = await generate_html(answer, toc_image_url)
await flyte.report.replace.aio(html_content, do_flush=True)
await flyte.report.flush.aio()
return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
LLM calls use LiteLLM, and each is wrapped with `flyte.trace` for observability:
```
from typing import Any, AsyncIterator, Optional

from litellm import acompletion, completion

import flyte


# {{docs-fragment asingle_shot_llm_call}}
@flyte.trace
async def asingle_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> AsyncIterator[str]:
    stream = await acompletion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated per the OpenAI API docs; use max_completion_tokens instead if possible
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we use max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
        stream=True,
    )
    async for chunk in stream:
        content = chunk.choices[0].delta.get("content", "")
        if content:
            yield content
# {{/docs-fragment asingle_shot_llm_call}}


def single_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> str:
    response = completion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated per the OpenAI API docs; use max_completion_tokens instead if possible
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we use max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
    )
    return response.choices[0].message["content"]  # type: ignore
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/libs/utils/llms.py)
> [!NOTE]
> We use `flyte.trace` to track intermediate steps within a task, like LLM calls or specific function executions. This lightweight decorator adds observability with minimal overhead and is especially useful for inspecting reasoning chains during task execution.
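As a minimal sketch, any small async helper can be traced the same way (this `count_words` helper is hypothetical; only the decorator usage mirrors the code above):
```
import flyte

# Assumed usage: flyte.trace on a small async helper, mirroring how
# asingle_shot_llm_call is decorated above.
@flyte.trace
async def count_words(text: str) -> int:
    # Hypothetical helper; each call would appear as a traced step in the run view.
    return len(text.split())
```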
## Search and summarize
We submit each research query to Tavily and summarize each result with an LLM. The per-result summarizations, and the per-query searches in `search_all_queries`, fan out with `asyncio.gather`, which signals to Flyte that they can be distributed across separate compute resources.
```
async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")
    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"{raw_content}\n\n{query}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result


# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""
    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)
    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue
            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )
        # Filter out exceptions
        summarized_contents = [
            result for result in summarized_contents if not isinstance(result, Exception)
        ]

    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}


@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    results_list = []
    if tasks:
        results_list = await asyncio.gather(*tasks)

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results
    return combined_results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
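The fan-out/fan-in shape above is plain `asyncio`; here is a stripped-down sketch of the same pattern with stand-in names (no Flyte or LLM calls involved):
```
import asyncio

async def summarize(doc: str) -> str:
    # Stand-in for an LLM summarization call
    return doc[:40]

async def fan_out(docs: list[str]) -> list[str]:
    # Launch all summaries concurrently; with return_exceptions=True,
    # a failure becomes a value instead of cancelling the whole batch.
    results = await asyncio.gather(
        *(summarize(d) for d in docs), return_exceptions=True
    )
    # Keep only the successful summaries
    return [r for r in results if not isinstance(r, Exception)]

print(asyncio.run(fan_out(["alpha " * 20, "beta " * 20])))
```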
## Evaluate research completeness
Next, we assess whether the gathered research is sufficient. As before, the task makes two LLM calls: one to evaluate the completeness of the results in free-form text, and a second to parse that evaluation into a structured list of follow-up queries. An empty list means the research is complete.
```
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """
    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{queries}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
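The two-call pattern, free-form reasoning followed by schema-constrained parsing, hinges on the JSON model emitting an object that matches a Pydantic schema. A minimal sketch of the parsing half (this single-field `ResearchPlan` is an assumption for illustration; the real model lives in `libs/utils/data_types.py`):
```
import json

from pydantic import BaseModel

class ResearchPlan(BaseModel):
    # Assumed single-field stand-in for the real ResearchPlan model
    queries: list[str]

# This JSON schema is what gets passed as response_format to the JSON model
schema = ResearchPlan.model_json_schema()

# Stand-in for the streamed response accumulated in response_json above
response_json = '{"queries": ["follow-up query 1", "follow-up query 2"]}'
plan = json.loads(response_json)
print(plan["queries"])
```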
## Filter results
In this step, we rank the search results by relevance and keep only the most useful sources for the final synthesis. One LLM call produces a free-form ranking, a second parses it into a structured source list, and the list is then capped at `max_sources`.
```
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""
    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)
    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    sources = json.loads(response_json)["sources"]
    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Filter the results based on the source list
    filtered_results = [
        results.results[i - 1] for i in sources if i - 1 < len(results.results)
    ]
    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
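Note that the parsing model returns 1-indexed source numbers; the selection logic above maps them back onto the 0-indexed results list and silently drops out-of-range entries. The same logic on toy data:
```
results = ["source A", "source B", "source C", "source D"]
sources = [3, 1, 9]  # 1-indexed picks from the LLM; 9 is out of range
max_sources = 2

if max_sources != -1:
    sources = sources[:max_sources]
filtered = [results[i - 1] for i in sources if i - 1 < len(results)]
print(filtered)  # ['source C', 'source A']
```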
## Generate the final answer
Finally, we generate a detailed research report by synthesizing the top-ranked results. This is the output returned to the user.
```
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
topic: str,
results: DeepResearchResults,
queries: list[str],
prompts_file: File,
planning_model: str,
json_model: str,
) -> list[str]:
"""
Evaluate if the current search results are sufficient or if more research is needed.
Returns an empty list if research is complete, or a list of additional queries if more research is needed.
"""
# Format the search results for the LLM
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
EVALUATION_PROMPT = prompts["evaluation_prompt"]
logging.info("\nEvaluation: ")
evaluation = ""
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=EVALUATION_PROMPT,
message=(
f"{topic}\n\n"
f"{queries}\n\n"
f"{formatted_results}"
),
response_format=None,
max_completion_tokens=None,
):
evaluation += chunk
print(chunk, end="", flush=True)
EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=EVALUATION_PARSING_PROMPT,
message=f"Evaluation to be parsed: {evaluation}",
response_format={
"type": "json_object",
"schema": ResearchPlan.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
evaluation = json.loads(response_json)
return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
topic: str,
results: DeepResearchResults,
prompts_file: File,
planning_model: str,
json_model: str,
max_sources: int,
) -> DeepResearchResults:
"""Filter the search results based on the research plan"""
# Format the search results for the LLM, without the raw content
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
FILTER_PROMPT = prompts["filter_prompt"]
logging.info("\nFilter response: ")
filter_response = ""
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=FILTER_PROMPT,
message=(
f"{topic}\n\n"
f"{formatted_results}"
),
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
filter_response += chunk
print(chunk, end="", flush=True)
logging.info(f"Filter response: {filter_response}")
FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=FILTER_PARSING_PROMPT,
message=f"Filter response to be parsed: {filter_response}",
response_format={
"type": "json_object",
"schema": SourceList.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
sources = json.loads(response_json)["sources"]
logging.info(f"Filtered sources: {sources}")
if max_sources != -1:
sources = sources[:max_sources]
# Filter the results based on the source list
filtered_results = [
results.results[i - 1] for i in sources if i - 1 < len(results.results)
]
return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
def _remove_thinking_tags(answer: str) -> str:
"""Remove content within tags"""
while "" in answer and "" in answer:
start = answer.find("")
end = answer.find("") + len("")
answer = answer[:start] + answer[end:]
return answer
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
topic: str,
results: DeepResearchResults,
remove_thinking_tags: bool,
prompts_file: File,
answer_model: str,
) -> str:
"""
Generate a comprehensive answer to the research topic based on the search results.
Returns a detailed response that synthesizes information from all search results.
"""
formatted_results = str(results)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
ANSWER_PROMPT = prompts["answer_prompt"]
answer = ""
async for chunk in asingle_shot_llm_call(
model=answer_model,
system_prompt=ANSWER_PROMPT,
message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
response_format=None,
# NOTE: This is the max_token parameter for the LLM call on Together AI,
# may need to be changed for other providers
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
answer += chunk
# this is just to avoid typing complaints
if answer is None or not isinstance(answer, str):
logging.error("No answer generated")
return "No answer generated"
if remove_thinking_tags:
# Remove content within tags
answer = _remove_thinking_tags(answer)
# Remove markdown code block markers if they exist at the beginning
if answer.lstrip().startswith("```"):
# Find the first line break after the opening backticks
first_linebreak = answer.find("\n", answer.find("```"))
if first_linebreak != -1:
# Remove everything up to and including the first line break
answer = answer[first_linebreak + 1 :]
# Remove closing code block if it exists
if answer.rstrip().endswith("```"):
answer = answer.rstrip()[:-3].rstrip()
return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
topic: str,
budget: int = 3,
remove_thinking_tags: bool = True,
max_queries: int = 5,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 40,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
prompts_file: File | str = "prompts.yaml",
) -> str:
"""Main method to conduct research on a topic. Will be used for weave evals."""
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
# Step 1: Generate initial queries
queries = await generate_research_queries(
topic=topic,
planning_model=planning_model,
json_model=json_model,
prompts_file=prompts_file,
)
queries = [topic, *queries[: max_queries - 1]]
all_queries = queries.copy()
logging.info(f"Initial queries: {queries}")
if len(queries) == 0:
logging.error("No initial queries generated")
return "No initial queries generated"
# Step 2: Perform initial search
results = await search_all_queries(queries, summarization_model, prompts_file)
logging.info(f"Initial search complete, found {len(results.results)} results")
# Step 3: Conduct iterative research within budget
for iteration in range(budget):
with flyte.group(f"eval_iteration_{iteration}"):
# Evaluate if more research is needed
additional_queries = await evaluate_research_completeness(
topic=topic,
results=results,
queries=all_queries,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
)
# Filter out empty strings and check if any queries remain
additional_queries = [q for q in additional_queries if q]
if not additional_queries:
logging.info("No need for additional research")
break
# for debugging purposes we limit the number of queries
additional_queries = additional_queries[:max_queries]
logging.info(f"Additional queries: {additional_queries}")
# Expand research with new queries
new_results = await search_all_queries(
additional_queries, summarization_model, prompts_file
)
logging.info(
f"Follow-up search complete, found {len(new_results.results)} results"
)
results = results + new_results
all_queries.extend(additional_queries)
# Step 4: Generate final answer
logging.info(f"Generating final answer for topic: {topic}")
results = results.dedup()
logging.info(f"Deduplication complete, kept {len(results.results)} results")
filtered_results = await filter_results(
topic=topic,
results=results,
prompts_file=prompts_file,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
)
logging.info(
f"LLM Filtering complete, kept {len(filtered_results.results)} results"
)
# Generate final answer
answer = await generate_research_answer(
topic=topic,
results=filtered_results,
remove_thinking_tags=remove_thinking_tags,
prompts_file=prompts_file,
answer_model=answer_model,
)
return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
topic: str = (
"List the essential requirements for a developer-focused agent orchestration system."
),
prompts_file: File | str = "/root/prompts.yaml",
budget: int = 2,
remove_thinking_tags: bool = True,
max_queries: int = 3,
answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
max_sources: int = 10,
summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
if isinstance(prompts_file, str):
prompts_file = await File.from_local(prompts_file)
answer = await research_topic(
topic=topic,
budget=budget,
remove_thinking_tags=remove_thinking_tags,
max_queries=max_queries,
answer_model=answer_model,
planning_model=planning_model,
json_model=json_model,
max_sources=max_sources,
summarization_model=summarization_model,
prompts_file=prompts_file,
)
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
toc_image_url = await generate_toc_image(
yaml.safe_load(yaml_contents)["data_visualization_prompt"],
planning_model,
topic,
)
html_content = await generate_html(answer, toc_image_url)
await flyte.report.replace.aio(html_content, do_flush=True)
await flyte.report.flush.aio()
return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
## Orchestration
Next, we define a `research_topic` task to orchestrate the entire deep research workflow. It runs the core stages in sequence: generating research queries, performing search and summarization, evaluating the completeness of results, and producing the final report.
```
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")
    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"
    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")
    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )
            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break
            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")
            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )
            results = results + new_results
            all_queries.extend(additional_queries)
    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )
    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )
    return answer
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
The `main` task wraps this entire pipeline and adds report generation in HTML format as the final step.
It also serves as the main entry point to the workflow, allowing us to pass in all configuration parameters, including which LLMs to use at each stage.
This flexibility lets us mix and match models for planning, summarization, and final synthesis, helping us optimize for both cost and quality.
```
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )
    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()
    return html_content


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
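Because all of these knobs are ordinary task parameters, you can override them for a single run without touching the code. A hypothetical invocation (the model name and budget values here are illustrative, not recommendations):
```
import flyte

flyte.init_from_config()
# Hypothetical override: a different planning model and a larger search budget.
run = flyte.run(
    main,  # the @env.task-decorated entry point defined above
    planning_model="together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    budget=3,
)
print(run.url)
```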
## Run the deep research agent
First, create the required secrets. The names must match the `key` values declared in the `TaskEnvironment`:
```
flyte create secret together_api_key <>
flyte create secret tavily_api_key <>
```
Run the agent:
```
uv run --prerelease=allow agent.py
```
If you want to test it locally first, run the following commands:
```
brew install pandoc
brew install basictex # restart your terminal after install
export TOGETHER_API_KEY=<>
export TAVILY_API_KEY=<>
uv run --prerelease=allow agent.py
```
## Evaluate with Weights & Biases Weave
We use W&B Weave to evaluate the full agent pipeline and analyze its responses. The evaluation runs as a Flyte pipeline and uses an LLM-as-a-judge scorer to measure the quality of the generated answers against reference answers.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "weave==0.51.51",
# "datasets==3.6.0",
# "huggingface-hub==0.32.6",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# ]
# ///
import os
import weave
from agent import research_topic
from datasets import load_dataset
from huggingface_hub import login
from libs.utils.log import AgentLogger
from litellm import completion
import flyte
logging = AgentLogger()
weave.init(project_name="deep-researcher")
env = flyte.TaskEnvironment(name="deep-researcher-eval")
@weave.op
def llm_as_a_judge_scoring(answer: str, output: str, question: str) -> bool:
    prompt = f"""
    Given the following question and answer, evaluate the answer against the correct answer:

    <question>{question}</question>
    <agent_answer>{output}</agent_answer>
    <correct_answer>{answer}</correct_answer>

    Note that the agent answer might be a long text containing a lot of information or it might be a short answer.
    You should read the entire text and think if the agent answers the question somewhere
    in the text. You should try to be flexible with the answer but careful.
    For example, answering with names instead of name and surname is fine.
    The important thing is that the answer of the agent either contains the correct answer or is equal to
    the correct answer.

    In case the answer is correct, return:
    <reasoning>The agent answer is correct because I can read that ....</reasoning>
    <answer>1</answer>
    Otherwise, return:
    <reasoning>The agent answer is incorrect because there is ...</reasoning>
    <answer>0</answer>
    """
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that returns a number between 0 and 1.",
        },
        {"role": "user", "content": prompt},
    ]
    answer = (
        completion(
            model="together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
            messages=messages,
            max_tokens=1000,
            temperature=0.0,
        )
        .choices[0]  # type: ignore
        .message["content"]  # type: ignore
    )
    return bool(int(answer.split("<answer>")[1].split("</answer>")[0].strip()))
def authenticate_huggingface():
"""Authenticate with Hugging Face Hub using token from environment variable."""
token = os.getenv("HUGGINGFACE_TOKEN")
if not token:
raise ValueError(
"HUGGINGFACE_TOKEN environment variable not set. "
"Please set it with your token from https://huggingface.co/settings/tokens"
)
try:
login(token=token)
print("Successfully authenticated with Hugging Face Hub")
except Exception as e:
raise RuntimeError(f"Failed to authenticate with Hugging Face Hub: {e!s}")
@env.task
async def load_questions(
dataset_names: list[str] | None = None,
) -> list[dict[str, str]]:
"""
Load questions from the specified Hugging Face dataset configurations.
Args:
dataset_names: List of dataset configurations to load
Options:
"smolagents:simpleqa",
"hotpotqa",
"simpleqa",
"together-search-bench"
If None, all available configurations except hotpotqa will be loaded
Returns:
List of question-answer pairs
"""
if dataset_names is None:
dataset_names = ["smolagents:simpleqa"]
all_questions = []
# Authenticate with Hugging Face Hub (once and for all)
authenticate_huggingface()
for dataset_name in dataset_names:
print(f"Loading dataset: {dataset_name}")
try:
if dataset_name == "together-search-bench":
# Load Together-Search-Bench dataset
dataset_path = "togethercomputer/together-search-bench"
ds = load_dataset(dataset_path)
if "test" in ds:
split_data = ds["test"]
else:
print(f"No 'test' split found in dataset at {dataset_path}")
continue
for i in range(len(split_data)):
item = split_data[i]
question_data = {
"question": item["question"],
"answer": item["answer"],
"dataset": item.get("dataset", "together-search-bench"),
}
all_questions.append(question_data)
print(f"Loaded {len(split_data)} questions from together-search-bench dataset")
continue
elif dataset_name == "hotpotqa":
# Load HotpotQA dataset (using distractor version for validation)
ds = load_dataset("hotpotqa/hotpot_qa", "distractor", trust_remote_code=True)
split_name = "validation"
elif dataset_name == "simpleqa":
ds = load_dataset("basicv8vc/SimpleQA")
split_name = "test"
else:
# Strip "smolagents:" prefix when loading the dataset
actual_dataset = dataset_name.split(":")[-1]
ds = load_dataset("smolagents/benchmark-v1", actual_dataset)
split_name = "test"
except Exception as e:
print(f"Failed to load dataset {dataset_name}: {e!s}")
continue # Skip this dataset if it fails to load
print(f"Dataset structure for {dataset_name}: {ds}")
print(f"Available splits: {list(ds)}")
split_data = ds[split_name] # type: ignore
for i in range(len(split_data)):
item = split_data[i]
if dataset_name == "hotpotqa":
# we remove questions that are easy or medium (if any) just to reduce the number of questions
if item["level"] != "hard":
continue
question_data = {
"question": item["question"],
"answer": item["answer"],
"dataset": dataset_name,
}
elif dataset_name == "simpleqa":
# Handle SimpleQA dataset format
question_data = {
"question": item["problem"],
"answer": item["answer"],
"dataset": dataset_name,
}
else:
question_data = {
"question": item["question"],
"answer": item["true_answer"],
"dataset": dataset_name,
}
all_questions.append(question_data)
print(f"Loaded {len(all_questions)} questions in total")
return all_questions
@weave.op
async def predict(question: str):
return await research_topic(topic=str(question))
@env.task
async def main(datasets: list[str] = ["together-search-bench"], limit: int | None = 1):
questions = await load_questions(datasets)
if limit is not None:
questions = questions[:limit]
print(f"Limited to {len(questions)} question(s)")
evaluation = weave.Evaluation(dataset=questions, scorers=[llm_as_a_judge_scoring])
await evaluation.evaluate(predict)
if __name__ == "__main__":
flyte.init_from_config()
flyte.with_runcontext(raw_data_path="data").run(main)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/weave_evals.py)
You can run this pipeline locally as follows:
```
export HUGGINGFACE_TOKEN=<> # https://huggingface.co/settings/tokens
export WANDB_API_KEY=<> # https://wandb.ai/settings
uv run --prerelease=allow weave_evals.py
```
The script will run all tasks in the pipeline and log the evaluation results to Weights & Biases.
While you can also evaluate individual tasks, this script focuses on end-to-end evaluation of the complete deep research workflow.
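You can also point the evaluation at a different benchmark or score more questions by passing arguments to `main`. A hypothetical invocation from within `weave_evals.py` (the dataset choice and limit are illustrative):
```
import flyte

flyte.init_from_config()
# Hypothetical: score five hard HotpotQA questions instead of the default single question.
flyte.with_runcontext(raw_data_path="data").run(main, datasets=["hotpotqa"], limit=5)
```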
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/hpo ===
# Hyperparameter optimization
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/ml/optimizer.py).
Hyperparameter Optimization (HPO) is a critical step in the machine learning (ML) lifecycle. Hyperparameters are the knobs and dials of a model: values such as learning rates, tree depths, or dropout rates that significantly impact performance but cannot be learned during training. Instead, we must select them manually or optimize them through guided search.
Model developers often enjoy the flexibility of choosing from a wide variety of model types, whether gradient boosted machines (GBMs), generalized linear models (GLMs), deep learning architectures, or dozens of others. A common challenge across all these options is the need to systematically explore model performance across hyperparameter configurations tailored to the specific dataset and task.
Thankfully, this exploration can be automated. Frameworks like [Optuna](https://optuna.org/), [Hyperopt](https://hyperopt.github.io/hyperopt/), and [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) use advanced sampling algorithms to efficiently search the hyperparameter space and identify optimal configurations. HPO may be executed in two distinct ways:
- **Serial HPO** runs one trial at a time, which is easy to set up but can be painfully slow.
- **Parallel HPO** distributes trials across multiple processes. It typically follows a pattern with two parameters: **_N_**, the total number of trials to run, and **_C_**, the maximum number of trials that can run concurrently. Trials are executed asynchronously, and new ones are scheduled based on the results and status of completed or in-progress ones.
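A minimal, framework-agnostic sketch of this **_N_**/**_C_** pattern in plain `asyncio` (the `objective` callable and trial IDs are illustrative placeholders):
```
import asyncio

async def run_trials(objective, n_trials: int, concurrency: int) -> list:
    # A semaphore caps the number of in-flight trials at C (`concurrency`).
    semaphore = asyncio.Semaphore(concurrency)

    async def one_trial(trial_id: int):
        async with semaphore:
            return await objective(trial_id)

    # Schedule all N trials; at most C run at any given moment.
    return await asyncio.gather(*(one_trial(i) for i in range(n_trials)))
```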
However, parallel HPO introduces a new complexity: the need for a centralized state that tracks:
- All past trials (successes and failures)
- All ongoing trials
This state is essential so that the optimization algorithm can make informed decisions about which hyperparameters to try next.
## A better way to run HPO
This is where Flyte shines.
- There's no need to manage a separate centralized database for state tracking, as every objective run is **cached**, **recorded**, and **recoverable** via Flyte's execution engine.
- The entire HPO process is observable in the UI with full lineage and metadata for each trial.
- Each objective is seeded for reproducibility, enabling deterministic trial results.
- If the main optimization task crashes or is terminated, **Flyte can resume from the last successful or failed trial, making the experiment highly fault-tolerant**.
- Trial functions can be strongly typed, enabling rich, flexible hyperparameter spaces while maintaining strict type safety across trials.
In this example, we combine Flyte with Optuna to optimize a `RandomForestClassifier` on the Iris dataset. Each trial runs in an isolated task, and the optimization process is orchestrated asynchronously, with Flyte handling the underlying scheduling, retries, and caching.
## Declare dependencies
We start by declaring a Python environment using Python 3.13 and specifying our runtime dependencies.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "optuna>=4.0.0,<5.0.0",
#     "flyte>=2.0.0b0",
#     "scikit-learn==1.7.0",
# ]
# ///
```
With the environment defined, we begin by importing standard library and third-party modules necessary for both the ML task and distributed execution.
```
import asyncio
import typing
from collections import Counter
from typing import Optional, Union
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
These standard library imports are essential for asynchronous execution (`asyncio`), type annotations (`typing`, `Optional`, `Union`), and aggregating trial state counts (`Counter`).
```
import optuna
from optuna import Trial
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We use Optuna for hyperparameter optimization and several utilities from scikit-learn to prepare data (`load_iris`), define the model (`RandomForestClassifier`), evaluate it (`cross_val_score`), and shuffle the dataset for randomness (`shuffle`).
```
import flyte
import flyte.errors
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
Flyte is our orchestration framework. We use it to define tasks, manage resources, and recover from execution errors.
## Define the task environment
We define a Flyte task environment called `driver`, which encapsulates metadata, compute resources, the container image context needed for remote execution, and caching behavior.
```
driver = flyte.TaskEnvironment(
name="driver",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="optimizer"),
cache="auto",
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
This environment specifies that the tasks will run with 1 CPU and 250Mi of memory, the image is built using the current script (`__file__`), and caching is enabled.
## Define the optimizer
Next, we define an `Optimizer` class that handles parallel execution of Optuna trials using async coroutines. This class abstracts the full optimization loop and supports concurrent trial execution with live logging.
```
class Optimizer:
def __init__(
self,
objective: callable,
n_trials: int,
concurrency: int = 1,
delay: float = 0.1,
study: Optional[optuna.Study] = None,
log_delay: float = 0.1,
):
self.n_trials: int = n_trials
self.concurrency: int = concurrency
self.objective: typing.Callable = objective
self.delay: float = delay
self.log_delay = log_delay
self.study = study if study else optuna.create_study()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We pass the `objective` function, the total number of trials to run (`n_trials`), and the maximum number of parallel trials (`concurrency`). The optional `delay` throttles execution between trials, while `log_delay` controls how often progress is logged. If no existing Optuna study is provided, a new one is created automatically.
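For instance, a driver task might construct the optimizer and simply await it (a hypothetical sketch; the trial counts are illustrative, and `__call__` is shown below):
```
@driver.task
async def optimize() -> None:
    # Hypothetical driver: 20 trials, at most 4 running concurrently.
    optimizer = Optimizer(objective, n_trials=20, concurrency=4)
    await optimizer()  # runs all trials via __call__ (shown below)
    print("Best parameters:", optimizer.study.best_params)
```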
```
async def log(self):
while True:
await asyncio.sleep(self.log_delay)
counter = Counter()
for trial in self.study.trials:
counter[trial.state.name.lower()] += 1
counts = dict(counter, queued=self.n_trials - len(self))
# print items in dictionary in a readable format
formatted = [f"{name}: {count}" for name, count in counts.items()]
print(f"{' '.join(formatted)}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
This method periodically prints the number of trials in each state (e.g., running, complete, fail). It keeps users informed of ongoing optimization progress and is invoked as a background task when logging is enabled.

_Logs are streamed live as the execution progresses._
```
async def spawn(self, semaphore: asyncio.Semaphore):
async with semaphore:
trial: Trial = self.study.ask()
try:
print("Starting trial", trial.number)
params = {
"n_estimators": trial.suggest_int("n_estimators", 10, 200),
"max_depth": trial.suggest_int("max_depth", 2, 20),
"min_samples_split": trial.suggest_float(
"min_samples_split", 0.1, 1.0
),
}
output = await self.objective(params)
self.study.tell(trial, output, state=optuna.trial.TrialState.COMPLETE)
except flyte.errors.RuntimeUserError as e:
print(f"Trial {trial.number} failed: {e}")
self.study.tell(trial, state=optuna.trial.TrialState.FAIL)
await asyncio.sleep(self.delay)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
Each call to `spawn` runs a single Optuna trial. The `semaphore` ensures that only a fixed number of concurrent trials are active at once, respecting the `concurrency` parameter. We first ask Optuna for a new trial and generate a parameter dictionary by querying the trial object for suggested hyperparameters. The trial is then evaluated by the objective function. If successful, we mark it as `COMPLETE`. If the trial fails due to a `RuntimeUserError` from Flyte, we log and record the failure in the Optuna study.
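The ask-and-tell pattern used in `spawn` is plain Optuna and works the same outside Flyte. A minimal standalone sketch with a toy objective:
```
import optuna

study = optuna.create_study(direction="maximize")
trial = study.ask()                        # request a new trial
x = trial.suggest_float("x", -10.0, 10.0)  # sample a hyperparameter
study.tell(trial, -((x - 2.0) ** 2))       # report the objective value
print(study.best_trial.params)
```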
```
async def __call__(self):
# create semaphore to manage concurrency
semaphore = asyncio.Semaphore(self.concurrency)
# create list of async trials
trials = [self.spawn(semaphore) for _ in range(self.n_trials)]
logger: Optional[asyncio.Task] = None
if self.log_delay:
logger = asyncio.create_task(self.log())
# await all trials to complete
await asyncio.gather(*trials)
if self.log_delay and logger:
logger.cancel()
try:
await logger
except asyncio.CancelledError:
pass
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
The `__call__` method defines the overall async optimization routine. It creates the semaphore, spawns `n_trials` coroutines, and optionally starts the background logging task. All trials are awaited with `asyncio.gather`, after which the background logger, if it was started, is cancelled.
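The concurrency control here is standard asyncio: each coroutine acquires a shared semaphore before doing work, and `asyncio.gather` awaits them all. A stripped-down sketch of the same pattern:
```
import asyncio

async def worker(i: int, sem: asyncio.Semaphore) -> int:
    async with sem:               # at most 3 workers run at once
        await asyncio.sleep(0.1)  # stand-in for a real trial
        return i

async def run_all() -> list[int]:
    sem = asyncio.Semaphore(3)
    return await asyncio.gather(*[worker(i, sem) for i in range(10)])

print(asyncio.run(run_all()))
```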
```
def __len__(self) -> int:
"""Return the number of trials in history."""
return len(self.study.trials)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
This method simply allows us to query the number of trials already associated with the study.
## Define the objective function
The objective task defines how we evaluate a particular set of hyperparameters. It's an async task, allowing for caching, tracking, and recoverability across executions.
```
@driver.task
async def objective(params: dict[str, Union[int, float]]) -> float:
data = load_iris()
X, y = shuffle(data.data, data.target, random_state=42)
clf = RandomForestClassifier(
n_estimators=params["n_estimators"],
max_depth=params["max_depth"],
min_samples_split=params["min_samples_split"],
random_state=42,
n_jobs=-1,
)
# Use cross-validation to evaluate performance
score = cross_val_score(clf, X, y, cv=3, scoring="accuracy").mean()
return score.item()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We use the Iris dataset as a toy classification problem. The input `params` dictionary contains the trial's hyperparameters, which we unpack into a `RandomForestClassifier`. We shuffle the dataset for randomness and compute 3-fold cross-validation accuracy.
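If you want to sanity-check the evaluation locally before launching trials, the same scikit-learn pipeline runs standalone with fixed (arbitrary) hyperparameters:
```
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle

data = load_iris()
X, y = shuffle(data.data, data.target, random_state=42)
clf = RandomForestClassifier(
    n_estimators=100, max_depth=5, min_samples_split=0.2, random_state=42
)
print(cross_val_score(clf, X, y, cv=3, scoring="accuracy").mean())
```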
## Define the main optimization loop
The optimize task is the main driver of our optimization experiment. It creates the `Optimizer` instance and invokes it.
```
@driver.task
async def optimize(
n_trials: int = 20,
concurrency: int = 5,
delay: float = 0.05,
log_delay: float = 0.1,
) -> dict[str, Union[int, float]]:
optimizer = Optimizer(
objective=objective,
n_trials=n_trials,
concurrency=concurrency,
delay=delay,
log_delay=log_delay,
study=optuna.create_study(
direction="maximize", sampler=optuna.samplers.TPESampler(seed=42)
),
)
await optimizer()
best = optimizer.study.best_trial
print("β Best Trial")
print(" Number :", best.number)
print(" Params :", best.params)
print(" Score :", best.value)
return best.params
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We configure Optuna's `TPESampler` and seed it for determinism. After running all trials, we extract the best-performing trial and print its parameters and score. Returning the best parameters lets downstream tasks or clients reuse the tuned configuration.
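Seeding the sampler makes the suggested hyperparameters reproducible across runs; a quick standalone check:
```
import optuna

def first_suggestion(seed: int) -> dict:
    study = optuna.create_study(
        direction="maximize", sampler=optuna.samplers.TPESampler(seed=seed)
    )
    trial = study.ask()
    trial.suggest_int("n_estimators", 10, 200)
    return trial.params

# The same seed yields the same first suggestion.
assert first_suggestion(42) == first_suggestion(42)
```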
## Run the experiment
Finally, we include an executable entry point to run this optimization using `flyte.run`.
```
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(optimize, 100, 10)
print(run.url)
run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We load the Flyte config from `config.yaml`, launch the `optimize` task with 100 trials and a concurrency of 10, and print a link to view the execution in the Flyte UI.
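Since `flyte.run` forwards keyword arguments to the underlying task, the same launch can be written more explicitly:
```
run = flyte.run(optimize, n_trials=100, concurrency=10)
```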

_Each objective run is cached, recorded, and recoverable. With concurrency set to 10, only 10 trials execute in parallel at any given time._
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/trading-agents ===
# Multi-agent trading simulation
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/trading_agents); based on work by [TauricResearch](https://github.com/TauricResearch/TradingAgents).
This example walks you through building a multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.

_Trading agents execution visualization_
## TL;DR
- You'll build a trading firm made up of agents that analyze, argue, and act, modeled with Python functions.
- You'll use the Flyte SDK to orchestrate this world, giving you visibility, retries, caching, and durability.
- You'll learn how to plug in tools, structure conversations, and track decisions across agents.
- You'll see how agents debate, use context, generate reports, and retain memory via vector DBs.
## What is an agent, anyway?
Agentic workflows are a rising pattern for complex problem-solving with LLMs. Think of agents as:
- An LLM (like GPT-4 or Mistral)
- A loop that keeps them thinking until a goal is met
- A set of optional tools they can call (APIs, search, calculators, etc.)
- Enough tokens to reason about the problem at hand
That's it.
You define tools, bind them to an agent, and let it run, reasoning step-by-step, optionally using those tools, until it finishes.
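In code, that loop can be as small as the following sketch. Note that `call_llm` and `run_tool` are hypothetical placeholders standing in for whatever LLM client and tool dispatch you use:
```
def agent(goal: str, tools: dict) -> str:
    messages = [{"role": "user", "content": goal}]
    for _ in range(10):  # cap iterations to avoid runaway loops
        reply = call_llm(messages, tools)  # hypothetical: the LLM may request a tool
        if reply.tool_call is None:        # no tool requested: the goal is met
            return reply.content
        result = run_tool(tools, reply.tool_call)  # hypothetical tool dispatch
        messages.append({"role": "tool", "content": str(result)})
    return "max iterations reached"
```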
## What's different here?
We're not building yet another agent framework. You're free to use LangChain, custom code, or whatever setup you like.
What we're giving you is the missing piece: a way to run these workflows **reliably, observably, and at scale, with zero rewrites.**
With Flyte, you get:
- Prompt + tool traceability and full state retention
- Built-in retries, caching, and failure recovery
- A native way to plug in your agents; no magic syntax required
## How it works: step-by-step walkthrough
This simulation is powered by a Flyte task that orchestrates multiple intelligent agents working together to analyze a company's stock and make informed trading decisions.

_Trading agents schema_
### Entry point
Everything begins with a top-level Flyte task called `main`, which serves as the entry point to the workflow.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "akshare==1.16.98",
# "backtrader==1.9.78.123",
# "boto3==1.39.9",
# "chainlit==2.5.5",
# "eodhd==1.0.32",
# "feedparser==6.0.11",
# "finnhub-python==2.4.23",
# "langchain-experimental==0.3.4",
# "langchain-openai==0.3.23",
# "pandas==2.3.0",
# "parsel==1.10.0",
# "praw==7.8.1",
# "pytz==2025.2",
# "questionary==2.1.0",
# "redis==6.2.0",
# "requests==2.32.4",
# "stockstats==0.6.5",
# "tqdm==4.67.1",
# "tushare==1.4.21",
# "typing-extensions==4.14.0",
# "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy
import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
create_neutral_debator,
create_risky_debator,
create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
reflect_bear_researcher,
reflect_bull_researcher,
reflect_research_manager,
reflect_risk_manager,
reflect_trader,
)
@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
"""Process a full trading signal to extract the core decision."""
messages = [
{
"role": "system",
"content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
},
{"role": "human", "content": full_signal},
]
return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content
async def run_analyst(analyst_name, state, online_tools):
# Create a copy of the state for isolation
run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")
# Run the analyst's chain
result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)
# Determine the report key
report_key = (
"sentiment_report"
if analyst_name == "social_media"
else f"{analyst_name}_report"
)
report_value = getattr(result_state, report_key)
return result_state.messages[1:], report_key, report_value
# {{docs-fragment main}}
@env.task
async def main(
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
if not selected_analysts:
raise ValueError(
"No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
)
state = AgentState(
messages=[{"role": "human", "content": company_name}],
company_of_interest=company_name,
trade_date=str(trade_date),
)
# Run all analysts concurrently
results = await asyncio.gather(
*[
run_analyst(analyst, deepcopy(state), online_tools)
for analyst in selected_analysts
]
)
# Flatten and append all resulting messages into the shared state
for messages, report_attr, report in results:
state.messages.extend(messages)
setattr(state, report_attr, report)
# Bull/Bear debate loop
state = await create_bull_researcher(QUICK_THINKING_LLM, state) # Start with bull
while state.investment_debate_state.count < 2 * max_debate_rounds:
current = state.investment_debate_state.current_response
if current.startswith("Bull"):
state = await create_bear_researcher(QUICK_THINKING_LLM, state)
else:
state = await create_bull_researcher(QUICK_THINKING_LLM, state)
state = await create_research_manager(DEEP_THINKING_LLM, state)
state = await create_trader(QUICK_THINKING_LLM, state)
# Risk debate loop
state = await create_risky_debator(QUICK_THINKING_LLM, state) # Start with risky
while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
speaker = state.risk_debate_state.latest_speaker
if speaker == "Risky":
state = await create_safe_debator(QUICK_THINKING_LLM, state)
elif speaker == "Safe":
state = await create_neutral_debator(QUICK_THINKING_LLM, state)
else:
state = await create_risky_debator(QUICK_THINKING_LLM, state)
state = await create_risk_manager(DEEP_THINKING_LLM, state)
decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)
return decision, state
# {{/docs-fragment main}}
# {{docs-fragment reflect_on_decisions}}
@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
await asyncio.gather(
reflect_bear_researcher(state, returns),
reflect_bull_researcher(state, returns),
reflect_trader(state, returns),
reflect_risk_manager(state, returns),
reflect_research_manager(state, returns),
)
return "Reflection completed."
# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
returns: str,
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> str:
_, state = await main(
selected_analysts,
max_debate_rounds,
max_risk_discuss_rounds,
online_tools,
company_name,
trade_date,
)
return await reflect_and_store(state, returns)
# {{/docs-fragment reflect_on_decisions}}
# {{docs-fragment execute_main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
# print(run.url)
# {{/docs-fragment execute_main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py)
This task accepts several inputs:
- the list of analysts to run,
- the number of debate and risk discussion rounds,
- a flag to enable online tools,
- the company you're evaluating,
- and the target trading date.
The most interesting parameter here is the list of analysts to run. It determines which analyst agents will be invoked and shapes the overall structure of the simulation. Based on this input, the task dynamically launches agent tasks, running them in parallel.
The `main` task is written as a regular asynchronous Python function wrapped with Flyte's task decorator. No domain-specific language or orchestration glue is needed: just idiomatic Python, optionally using async for better performance. The task environment is configured once and shared across all tasks for consistency.
```
# {{docs-fragment env}}
import flyte
QUICK_THINKING_LLM = "gpt-4o-mini"
DEEP_THINKING_LLM = "o4-mini"
env = flyte.TaskEnvironment(
name="trading-agents",
secrets=[
flyte.Secret(key="finnhub_api_key", as_env_var="FINNHUB_API_KEY"),
flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY"),
],
image=flyte.Image.from_uv_script("main.py", name="trading-agents", pre=True),
resources=flyte.Resources(cpu="1"),
cache="auto",
)
# {{/docs-fragment env}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/flyte_env.py)
### Analyst agents
Each analyst agent comes equipped with a set of tools and a carefully designed prompt tailored to its specific domain. These tools are modular Flyte tasks (for example, downloading financial reports or computing technical indicators) and benefit from Flyte's built-in caching to avoid redundant computation.
```
from datetime import datetime
import pandas as pd
import tools.interface as interface
import yfinance as yf
from flyte_env import env
from flyte.io import File
@env.task
async def get_reddit_news(
curr_date: str, # Date you want to get news for in yyyy-mm-dd format
) -> str:
"""
Retrieve global news from Reddit within a specified time frame.
Args:
curr_date (str): Date you want to get news for in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest global news
from Reddit in the specified time frame.
"""
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
@env.task
async def get_finnhub_news(
ticker: str, # Search query of a company, e.g. 'AAPL, TSM, etc.
start_date: str, # Start date in yyyy-mm-dd format
end_date: str, # End date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest news about a given stock from Finnhub within a date range
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing news about the company
within the date range from start_date to end_date
"""
end_date_str = end_date
end_date = datetime.strptime(end_date, "%Y-%m-%d")
start_date = datetime.strptime(start_date, "%Y-%m-%d")
look_back_days = (end_date - start_date).days
finnhub_news_result = interface.get_finnhub_news(
ticker, end_date_str, look_back_days
)
return finnhub_news_result
@env.task
async def get_reddit_stock_info(
ticker: str, # Ticker of a company. e.g. AAPL, TSM
curr_date: str, # Current date you want to get news for
) -> str:
"""
Retrieve the latest news about a given stock from Reddit, given the current date.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): current date in yyyy-mm-dd format to get news for
Returns:
str: A formatted dataframe containing the latest news about the company on the given date
"""
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
return stock_news_results
@env.task
async def get_YFin_data(
symbol: str, # ticker symbol of the company
start_date: str, # Start date in yyyy-mm-dd format
end_date: str, # End date in yyyy-mm-dd format
) -> str:
"""
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the stock price data
for the specified ticker symbol in the specified date range.
"""
result_data = interface.get_YFin_data(symbol, start_date, end_date)
return result_data
@env.task
async def get_YFin_data_online(
symbol: str, # ticker symbol of the company
start_date: str, # Start date in yyyy-mm-dd format
end_date: str, # End date in yyyy-mm-dd format
) -> str:
"""
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the stock price data
for the specified ticker symbol in the specified date range.
"""
result_data = interface.get_YFin_data_online(symbol, start_date, end_date)
return result_data
@env.task
async def cache_market_data(symbol: str, start_date: str, end_date: str) -> File:
data_file = f"{symbol}-YFin-data-{start_date}-{end_date}.csv"
data = yf.download(
symbol,
start=start_date,
end=end_date,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
return await File.from_local(data_file)
@env.task
async def get_stockstats_indicators_report(
symbol: str, # ticker symbol of the company
indicator: str, # technical indicator to get the analysis and report of
curr_date: str, # The current trading date you are trading on, YYYY-mm-dd
look_back_days: int = 30, # how many days to look back
) -> str:
"""
Retrieve stock stats indicators for a given ticker symbol and indicator.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
indicator (str): Technical indicator to get the analysis and report of
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
look_back_days (int): How many days to look back, default is 30
Returns:
str: A formatted dataframe containing the stock stats indicators
for the specified ticker symbol and indicator.
"""
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = start_date.strftime("%Y-%m-%d")
end_date = end_date.strftime("%Y-%m-%d")
data_file = await cache_market_data(symbol, start_date, end_date)
local_data_file = await data_file.download()
result_stockstats = interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, False, local_data_file
)
return result_stockstats
# {{docs-fragment get_stockstats_indicators_report_online}}
@env.task
async def get_stockstats_indicators_report_online(
symbol: str, # ticker symbol of the company
indicator: str, # technical indicator to get the analysis and report of
curr_date: str, # The current trading date you are trading on, YYYY-mm-dd"
look_back_days: int = 30, # "how many days to look back"
) -> str:
"""
Retrieve stock stats indicators for a given ticker symbol and indicator.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
indicator (str): Technical indicator to get the analysis and report of
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
look_back_days (int): How many days to look back, default is 30
Returns:
str: A formatted dataframe containing the stock stats indicators
for the specified ticker symbol and indicator.
"""
today_date = pd.Timestamp.today()
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = start_date.strftime("%Y-%m-%d")
end_date = end_date.strftime("%Y-%m-%d")
data_file = await cache_market_data(symbol, start_date, end_date)
local_data_file = await data_file.download()
result_stockstats = interface.get_stock_stats_indicators_window(
symbol, indicator, curr_date, look_back_days, True, local_data_file
)
return result_stockstats
# {{/docs-fragment get_stockstats_indicators_report_online}}
@env.task
async def get_finnhub_company_insider_sentiment(
ticker: str, # ticker symbol for the company
curr_date: str, # current date of you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve insider sentiment information about a company (retrieved
from public SEC information) for the past 30 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the sentiment in the past 30 days starting at curr_date
"""
data_sentiment = interface.get_finnhub_company_insider_sentiment(
ticker, curr_date, 30
)
return data_sentiment
@env.task
async def get_finnhub_company_insider_transactions(
ticker: str, # ticker symbol
curr_date: str, # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve insider transaction information about a company
(retrieved from public SEC information) for the past 30 days
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's insider transactions/trading information in the past 30 days
"""
data_trans = interface.get_finnhub_company_insider_transactions(
ticker, curr_date, 30
)
return data_trans
@env.task
async def get_simfin_balance_sheet(
ticker: str, # ticker symbol
freq: str, # reporting frequency of the company's financial history: annual/quarterly
curr_date: str, # current date you are trading at, yyyy-mm-dd
):
"""
Retrieve the most recent balance sheet of a company
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency of the company's financial history: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's most recent balance sheet
"""
data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)
return data_balance_sheet
@env.task
async def get_simfin_cashflow(
ticker: str, # ticker symbol
freq: str, # reporting frequency of the company's financial history: annual/quarterly
curr_date: str, # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve the most recent cash flow statement of a company
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency of the company's financial history: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's most recent cash flow statement
"""
data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)
return data_cashflow
@env.task
async def get_simfin_income_stmt(
ticker: str, # ticker symbol
freq: str, # reporting frequency of the company's financial history: annual/quarterly
curr_date: str, # current date you are trading at, yyyy-mm-dd
) -> str:
"""
Retrieve the most recent income statement of a company
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency of the company's financial history: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: a report of the company's most recent income statement
"""
data_income_stmt = interface.get_simfin_income_statements(ticker, freq, curr_date)
return data_income_stmt
@env.task
async def get_google_news(
query: str, # Query to search with
curr_date: str, # Curr date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest news from Google News based on a query and date range.
Args:
query (str): Query to search with
curr_date (str): Current date in yyyy-mm-dd format
look_back_days (int): How many days to look back
Returns:
str: A formatted string containing the latest news from Google News
based on the query and date range.
"""
google_news_results = interface.get_google_news(query, curr_date, 7)
return google_news_results
@env.task
async def get_stock_news_openai(
ticker: str, # the company's ticker
curr_date: str, # Current date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest news about a given stock by using OpenAI's news API.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest news about the company on the given date.
"""
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
return openai_news_results
@env.task
async def get_global_news_openai(
curr_date: str, # Current date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
Args:
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest macroeconomic news on the given date.
"""
openai_news_results = interface.get_global_news_openai(curr_date)
return openai_news_results
@env.task
async def get_fundamentals_openai(
ticker: str, # the company's ticker
curr_date: str, # Current date in yyyy-mm-dd format
) -> str:
"""
Retrieve the latest fundamental information about a given stock
on a given date by using OpenAI's news API.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A formatted string containing the latest fundamental information
about the company on the given date.
"""
openai_fundamentals_results = interface.get_fundamentals_openai(ticker, curr_date)
return openai_fundamentals_results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/tools/toolkit.py)
When initialized, an analyst enters a structured reasoning loop (via LangChain), where it can call tools, observe outputs, and refine its internal state before generating a final report. These reports are later consumed by downstream agents.
Here's an example of a news analyst that interprets global events and macroeconomic signals. We specify the tools accessible to the analyst, and the LLM selects which ones to use based on context.
```
import asyncio
from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit
import flyte
MAX_ITERATIONS = 5
# {{docs-fragment agent_helper}}
async def run_chain_with_tools(
type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
" For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join(tool_names))
prompt = prompt.partial(current_date=state.trade_date)
prompt = prompt.partial(ticker=state.company_of_interest)
chain = prompt | ChatOpenAI(model=llm).bind_tools(
[getattr(toolkit, tool_name).func for tool_name in tool_names]
)
iteration = 0
while iteration < MAX_ITERATIONS:
result = await chain.ainvoke(state.messages)
state.messages.append(convert_to_openai_messages(result))
if not result.tool_calls:
            # Final response; no tools required
setattr(state, f"{type}_report", result.content or "")
break
# Run all tool calls in parallel
async def run_single_tool(tool_call):
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool = getattr(toolkit, tool_name, None)
if not tool:
return None
content = await tool(**tool_args)
return ToolMessage(
tool_call_id=tool_call["id"], name=tool_name, content=content
)
with flyte.group(f"tool_calls_iteration_{iteration}"):
tool_messages = await asyncio.gather(
*[run_single_tool(tc) for tc in result.tool_calls]
)
# Add valid tool results to state
tool_messages = [msg for msg in tool_messages if msg]
state.messages.extend(convert_to_openai_messages(tool_messages))
iteration += 1
else:
        # Reached iteration cap; optionally raise or log
print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")
return state
# {{/docs-fragment agent_helper}}
@env.task
async def create_fundamentals_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [toolkit.get_fundamentals_openai]
else:
tools = [
toolkit.get_finnhub_company_insider_sentiment,
toolkit.get_finnhub_company_insider_transactions,
toolkit.get_simfin_balance_sheet,
toolkit.get_simfin_cashflow,
toolkit.get_simfin_income_stmt,
]
system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. "
"Please write a comprehensive report of the company's fundamental information such as financial documents, "
"company profile, basic company financials, company financial history, insider sentiment, and insider "
"transactions to gain a full view of the company's "
"fundamental information to inform traders. Make sure to include as much detail as possible. "
"Do not simply state the trends are mixed, "
"provide detailed and finegrained analysis and insights that may help traders make decisions. "
"Make sure to append a Markdown table at the end of the report to organize key points in the report, "
"organized and easy to read."
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools(
"fundamentals", state, llm, system_message, tool_names
)
@env.task
async def create_market_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [
toolkit.get_YFin_data_online,
toolkit.get_stockstats_indicators_report_online,
]
else:
tools = [
toolkit.get_YFin_data,
toolkit.get_stockstats_indicators_report,
]
system_message = (
"""You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:
Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.
MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.
Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.
Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and finegrained analysis
and insights that may help traders make decisions."""
""" Make sure to append a Markdown table at the end of the report to
organize key points in the report, organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools("market", state, llm, system_message, tool_names)
# {{docs-fragment news_analyst}}
@env.task
async def create_news_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [
toolkit.get_global_news_openai,
toolkit.get_google_news,
]
else:
tools = [
toolkit.get_finnhub_news,
toolkit.get_reddit_news,
toolkit.get_google_news,
]
system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. "
"Please write a comprehensive report of the current state of the world that is relevant for "
"trading and macroeconomics. "
"Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, "
"provide detailed and finegrained analysis and insights that may help traders make decisions."
""" Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools("news", state, llm, system_message, tool_names)
# {{/docs-fragment news_analyst}}
@env.task
async def create_social_media_analyst(
llm: str, state: AgentState, online_tools: bool
) -> AgentState:
if online_tools:
tools = [toolkit.get_stock_news_openai]
else:
tools = [toolkit.get_reddit_stock_info]
system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, "
"recent company news, and public sentiment for a specific company over the past week. "
"You will be given a company's name your objective is to write a comprehensive long report "
"detailing your analysis, insights, and implications for traders and investors on this company's current state "
"after looking at social media and what people are saying about that company, "
"analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
"Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
"are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
""" Make sure to append a Makrdown table at the end of the report to organize key points in the report,
organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools(
"sentiment", state, llm, system_message, tool_names
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py)
Each analyst agent uses a helper function to bind tools, iterate through reasoning steps (up to a configurable maximum), and produce an answer. Setting a max iteration count is crucial to prevent runaway loops. As agents reason, their message history is preserved in their internal state and passed along to the next agent in the chain.
Once all analyst reports are complete, their outputs are collected and passed to the next stage of the workflow.
### Research agents
The research phase consists of two agents: a bullish researcher and a bearish one. They evaluate the company from opposing viewpoints, drawing on the analysts' reports. Unlike analysts, they don't use tools. Their role is to interpret, critique, and develop positions based on the evidence.
```
from agents.utils.utils import AgentState, InvestmentDebateState, memory_init
from flyte_env import env
from langchain_openai import ChatOpenAI
# {{docs-fragment bear_researcher}}
@env.task
async def create_bear_researcher(llm: str, state: AgentState) -> AgentState:
investment_debate_state = state.investment_debate_state
history = investment_debate_state.history
bear_history = investment_debate_state.bear_history
current_response = investment_debate_state.current_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="bear-researcher")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bear Analyst making the case against investing in the stock.
Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators.
Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
Key points to focus on:
- Risks and Challenges: Highlight factors like market saturation, financial instability,
or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation,
or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning,
exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points
and debating effectively rather than simply listing facts.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate
that demonstrates the risks and weaknesses of investing in the stock.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Bear Analyst: {response.content}"
new_investment_debate_state = InvestmentDebateState(
history=history + "\n" + argument,
bear_history=bear_history + "\n" + argument,
bull_history=investment_debate_state.bull_history,
current_response=argument,
count=investment_debate_state.count + 1,
)
state.investment_debate_state = new_investment_debate_state
return state
# {{/docs-fragment bear_researcher}}
@env.task
async def create_bull_researcher(llm: str, state: AgentState) -> AgentState:
investment_debate_state = state.investment_debate_state
history = investment_debate_state.history
bull_history = investment_debate_state.bull_history
current_response = investment_debate_state.current_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="bull-researcher")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bull Analyst advocating for investing in the stock.
Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages,
and positive market indicators.
Leverage the provided research and data to address concerns and counter bearish arguments effectively.
Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing
concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points
and debating effectively rather than just listing data.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate
that demonstrates the strengths of the bull position.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Bull Analyst: {response.content}"
new_investment_debate_state = InvestmentDebateState(
history=history + "\n" + argument,
bull_history=bull_history + "\n" + argument,
bear_history=investment_debate_state.bear_history,
current_response=argument,
count=investment_debate_state.count + 1,
)
state.investment_debate_state = new_investment_debate_state
return state
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/researchers.py)
To aid reasoning, the agents can also retrieve relevant "memories" from a vector database, giving them richer historical context. The number of debate rounds is configurable, and after a few iterations of back-and-forth between the bull and bear, a research manager agent reviews their arguments and makes a final investment decision.
```
from agents.utils.utils import (
AgentState,
InvestmentDebateState,
RiskDebateState,
memory_init,
)
from flyte_env import env
from langchain_openai import ChatOpenAI
# {{docs-fragment research_manager}}
@env.task
async def create_research_manager(llm: str, state: AgentState) -> AgentState:
history = state.investment_debate_state.history
investment_debate_state = state.investment_debate_state
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="research-manager")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate
this round of debate and make a definitive decision:
align with the bear analyst, the bull analyst,
or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning.
Your recommendation (Buy, Sell, or Hold) must be clear and actionable.
Avoid defaulting to Hold simply because both sides have valid points;
commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes in similar situations.
Use these insights to refine your decision-making and ensure you are learning and improving.
Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"
Here is the debate:
Debate History:
{history}"""
response = ChatOpenAI(model=llm).invoke(prompt)
new_investment_debate_state = InvestmentDebateState(
judge_decision=response.content,
history=investment_debate_state.history,
bear_history=investment_debate_state.bear_history,
bull_history=investment_debate_state.bull_history,
current_response=response.content,
count=investment_debate_state.count,
)
state.investment_debate_state = new_investment_debate_state
state.investment_plan = response.content
return state
# {{/docs-fragment research_manager}}
@env.task
async def create_risk_manager(llm: str, state: AgentState) -> AgentState:
history = state.risk_debate_state.history
risk_debate_state = state.risk_debate_state
trader_plan = state.investment_plan
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="risk-manager")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the Risk Management Judge and Debate Facilitator,
your goal is to evaluate the debate between three risk analysts (Risky, Neutral, and Safe/Conservative)
and determine the best course of action for the trader.
Your decision must result in a clear recommendation: Buy, Sell, or Hold.
Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid.
Strive for clarity and decisiveness.
Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**,
and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments
and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.
---
**Analysts Debate History:**
{history}
---
Focus on actionable insights and continuous improvement.
Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
response = ChatOpenAI(model=llm).invoke(prompt)
new_risk_debate_state = RiskDebateState(
judge_decision=response.content,
history=risk_debate_state.history,
risky_history=risk_debate_state.risky_history,
safe_history=risk_debate_state.safe_history,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Judge",
current_risky_response=risk_debate_state.current_risky_response,
current_safe_response=risk_debate_state.current_safe_response,
current_neutral_response=risk_debate_state.current_neutral_response,
count=risk_debate_state.count,
)
state.risk_debate_state = new_risk_debate_state
state.final_trade_decision = response.content
return state
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/managers.py)
### Trading agent
The trader agent consolidates the insights from analysts and researchers to generate a final recommendation. It synthesizes competing signals and produces a conclusion such as _Buy for long-term growth despite short-term volatility_.
```
from agents.utils.utils import AgentState, memory_init
from flyte_env import env
from langchain_core.messages import convert_to_openai_messages
from langchain_openai import ChatOpenAI
# {{docs-fragment trader}}
@env.task
async def create_trader(llm: str, state: AgentState) -> AgentState:
company_name = state.company_of_interest
investment_plan = state.investment_plan
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
memory = await memory_init(name="trader")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for rec in past_memories:
past_memory_str += rec["recommendation"] + "\n\n"
context = {
"role": "user",
"content": f"Based on a comprehensive analysis by a team of analysts, "
f"here is an investment plan tailored for {company_name}. "
"This plan incorporates insights from current technical market trends, "
"macroeconomic indicators, and social media sentiment. "
"Use this plan as a foundation for evaluating your next trading decision.\n\n"
f"Proposed Investment Plan: {investment_plan}\n\n"
"Leverage these insights to make an informed and strategic decision.",
}
messages = [
{
"role": "system",
"content": f"""You are a trading agent analyzing market data to make investment decisions.
Based on your analysis, provide a specific recommendation to buy, sell, or hold.
End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'
to confirm your recommendation.
Do not forget to utilize lessons from past decisions to learn from your mistakes.
Here are some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
},
context,
]
result = ChatOpenAI(model=llm).invoke(messages)
state.messages.append(convert_to_openai_messages(result))
state.trader_investment_plan = result.content
state.sender = "Trader"
return state
# {{/docs-fragment trader}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/trader.py)
### Risk agents
The risk team consists of agents with different risk tolerances: a risky debater, a neutral one, and a conservative one. They assess the portfolio through lenses like market volatility, liquidity, and systemic risk. As in the bull-bear debate, these agents engage in internal discussion, after which a risk manager makes the final call.
```
from agents.utils.utils import AgentState, RiskDebateState
from flyte_env import env
from langchain_openai import ChatOpenAI
# {{docs-fragment risk_debator}}
@env.task
async def create_risky_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
risky_history = risk_debate_state.risky_history
current_safe_response = risk_debate_state.current_safe_response
current_neutral_response = risk_debate_state.current_neutral_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities,
emphasizing bold strategies and competitive advantages.
When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential,
and innovative benefits, even when these come with elevated risk.
Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views.
Specifically, respond directly to each point made by the conservative and neutral analysts,
countering with data-driven rebuttals and persuasive reasoning.
Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative.
Here is the trader's decision:
{trader_decision}
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative
and neutral stances to demonstrate why your high-reward perspective offers the best path forward.
Incorporate insights from the following sources into your arguments:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here are the last arguments from the conservative analyst: {current_safe_response}
Here are the last arguments from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic,
and asserting the benefits of risk-taking to outpace market norms.
Maintain a focus on debating and persuading, not just presenting data.
Challenge each counterpoint to underscore why a high-risk approach is optimal.
Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Risky Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risky_history + "\n" + argument,
safe_history=risk_debate_state.safe_history,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Risky",
current_risky_response=argument,
current_safe_response=current_safe_response,
current_neutral_response=current_neutral_response,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
# {{/docs-fragment risk_debator}}
@env.task
async def create_safe_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
safe_history = risk_debate_state.safe_history
current_risky_response = risk_debate_state.current_risky_response
current_neutral_response = risk_debate_state.current_neutral_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets,
minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation,
carefully assessing potential losses, economic downturns, and market volatility.
When evaluating the trader's decision or plan, critically examine high-risk elements,
pointing out where the decision may expose the firm to undue risk and where more cautious
alternatives could secure long-term gains.
Here is the trader's decision:
{trader_decision}
Your task is to actively counter the arguments of the Risky and Neutral Analysts,
highlighting where their views may overlook potential threats or fail to prioritize sustainability.
Respond directly to their points, drawing from the following data sources
to build a convincing case for a low-risk approach adjustment to the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked.
Address each of their counterpoints to showcase why a conservative stance is ultimately the
safest path for the firm's assets.
Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy
over their approaches.
Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Safe Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risk_debate_state.risky_history,
safe_history=safe_history + "\n" + argument,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Safe",
current_risky_response=current_risky_response,
current_safe_response=argument,
current_neutral_response=current_neutral_response,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
@env.task
async def create_neutral_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
neutral_history = risk_debate_state.neutral_history
current_risky_response = risk_debate_state.current_risky_response
current_safe_response = risk_debate_state.current_safe_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective,
weighing both the potential benefits and risks of the trader's decision or plan.
You prioritize a well-rounded approach, evaluating the upsides
and downsides while factoring in broader market trends,
potential economic shifts, and diversification strategies. Here is the trader's decision:
{trader_decision}
Your task is to challenge both the Risky and Safe Analysts,
pointing out where each perspective may be overly optimistic or overly cautious.
Use insights from the following data sources to support a moderate, sustainable strategy
to adjust the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the safe analyst: {current_safe_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by analyzing both sides critically, addressing weaknesses in the risky
and conservative arguments to advocate for a more balanced approach.
Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds,
providing growth potential while safeguarding against extreme volatility.
Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to
the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Neutral Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risk_debate_state.risky_history,
safe_history=risk_debate_state.safe_history,
neutral_history=neutral_history + "\n" + argument,
latest_speaker="Neutral",
current_risky_response=current_risky_response,
current_safe_response=current_safe_response,
current_neutral_response=argument,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/risk_debators.py)
The outcome of the risk manager, whether to proceed with the trade or not, is considered the final decision of the trading simulation.
You can visualize this full pipeline in the Flyte/Union UI, where every step is logged.
You'll see input/output metadata for each tool and agent task.
Thanks to Flyte's caching, repeated steps are skipped unless inputs change, saving time and compute resources.
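For example, turning caching on for a task is a one-line change. Below is a minimal sketch; the `cache="auto"` shorthand is an assumption, mirroring the `cache="disable"` setting used on the reflection task later in this tutorial.
```
import flyte

env = flyte.TaskEnvironment("caching_demo")

# Assumption: "auto" enables input-keyed caching, the counterpart of the
# cache="disable" setting used on reflect_on_decisions below
@env.task(cache="auto")
async def fetch_fundamentals(ticker: str, trade_date: str) -> str:
    # With caching enabled, re-running with identical inputs reuses the
    # stored result instead of executing the task again
    return f"fundamentals for {ticker} on {trade_date}"
```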
### Retaining agent memory with S3 vectors
To help agents learn from past decisions, we persist their memory in a vector store. In this example, we use an [S3 vector](https://aws.amazon.com/s3/features/vectors/) bucket for its simplicity and tight integration with Flyte and Union, but any vector database can be used.
> [!NOTE]
> To use the S3 vector store, make sure your IAM role has the following permissions configured:
>
> ```
> s3vectors:CreateVectorBucket
> s3vectors:CreateIndex
> s3vectors:PutVectors
> s3vectors:GetIndex
> s3vectors:GetVectors
> s3vectors:QueryVectors
> s3vectors:GetVectorBucket
> ```
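Under the hood, helpers like `memory_init` and `get_memories` can be backed by S3 Vectors via `boto3`. The sketch below is illustrative rather than the tutorial's actual `agents.utils.utils` implementation: the class name, the embedding model, and the exact `s3vectors` parameter names are assumptions.
```
import boto3
from langchain_openai import OpenAIEmbeddings

class FinancialMemory:
    """Minimal sketch of an S3 Vectors-backed memory store (illustrative)."""

    def __init__(self, bucket: str, index: str):
        # Assumes the boto3 "s3vectors" client and an existing bucket/index
        self.client = boto3.client("s3vectors")
        self.bucket = bucket
        self.index = index
        self.embedder = OpenAIEmbeddings(model="text-embedding-3-small")

    def add_memory(self, situation: str, recommendation: str, key: str) -> None:
        # Embed the market situation and store the lesson as metadata
        self.client.put_vectors(
            vectorBucketName=self.bucket,
            indexName=self.index,
            vectors=[{
                "key": key,
                "data": {"float32": self.embedder.embed_query(situation)},
                "metadata": {"recommendation": recommendation},
            }],
        )

    def get_memories(self, situation: str, n_matches: int = 2) -> list[dict]:
        # Nearest-neighbor lookup over past situations
        response = self.client.query_vectors(
            vectorBucketName=self.bucket,
            indexName=self.index,
            queryVector={"float32": self.embedder.embed_query(situation)},
            topK=n_matches,
            returnMetadata=True,
        )
        # Callers read rec["recommendation"], matching the agents above
        return [v["metadata"] for v in response["vectors"]]
```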
After each trade decision, you can run a `reflect_on_decisions` task. This evaluates whether the final outcome aligned with the agent's recommendation and stores that reflection in the vector store. These stored insights can later be retrieved to provide historical context and improve future decision-making.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "akshare==1.16.98",
# "backtrader==1.9.78.123",
# "boto3==1.39.9",
# "chainlit==2.5.5",
# "eodhd==1.0.32",
# "feedparser==6.0.11",
# "finnhub-python==2.4.23",
# "langchain-experimental==0.3.4",
# "langchain-openai==0.3.23",
# "pandas==2.3.0",
# "parsel==1.10.0",
# "praw==7.8.1",
# "pytz==2025.2",
# "questionary==2.1.0",
# "redis==6.2.0",
# "requests==2.32.4",
# "stockstats==0.6.5",
# "tqdm==4.67.1",
# "tushare==1.4.21",
# "typing-extensions==4.14.0",
# "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy
import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
create_neutral_debator,
create_risky_debator,
create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
reflect_bear_researcher,
reflect_bull_researcher,
reflect_research_manager,
reflect_risk_manager,
reflect_trader,
)
@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
"""Process a full trading signal to extract the core decision."""
messages = [
{
"role": "system",
"content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
},
{"role": "human", "content": full_signal},
]
return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content
async def run_analyst(analyst_name, state, online_tools):
# The caller passes a deepcopy of the state so each analyst runs in isolation
run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")
# Run the analyst's chain
result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)
# Determine the report key
report_key = (
"sentiment_report"
if analyst_name == "social_media"
else f"{analyst_name}_report"
)
report_value = getattr(result_state, report_key)
return result_state.messages[1:], report_key, report_value
# {{docs-fragment main}}
@env.task
async def main(
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
if not selected_analysts:
raise ValueError(
"No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
)
state = AgentState(
messages=[{"role": "human", "content": company_name}],
company_of_interest=company_name,
trade_date=str(trade_date),
)
# Run all analysts concurrently
results = await asyncio.gather(
*[
run_analyst(analyst, deepcopy(state), online_tools)
for analyst in selected_analysts
]
)
# Flatten and append all resulting messages into the shared state
for messages, report_attr, report in results:
state.messages.extend(messages)
setattr(state, report_attr, report)
# Bull/Bear debate loop
state = await create_bull_researcher(QUICK_THINKING_LLM, state) # Start with bull
while state.investment_debate_state.count < 2 * max_debate_rounds:
current = state.investment_debate_state.current_response
if current.startswith("Bull"):
state = await create_bear_researcher(QUICK_THINKING_LLM, state)
else:
state = await create_bull_researcher(QUICK_THINKING_LLM, state)
state = await create_research_manager(DEEP_THINKING_LLM, state)
state = await create_trader(QUICK_THINKING_LLM, state)
# Risk debate loop
state = await create_risky_debator(QUICK_THINKING_LLM, state) # Start with risky
while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
speaker = state.risk_debate_state.latest_speaker
if speaker == "Risky":
state = await create_safe_debator(QUICK_THINKING_LLM, state)
elif speaker == "Safe":
state = await create_neutral_debator(QUICK_THINKING_LLM, state)
else:
state = await create_risky_debator(QUICK_THINKING_LLM, state)
state = await create_risk_manager(DEEP_THINKING_LLM, state)
decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)
return decision, state
# {{/docs-fragment main}}
# {{docs-fragment reflect_on_decisions}}
@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
await asyncio.gather(
reflect_bear_researcher(state, returns),
reflect_bull_researcher(state, returns),
reflect_trader(state, returns),
reflect_risk_manager(state, returns),
reflect_research_manager(state, returns),
)
return "Reflection completed."
# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
returns: str,
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> str:
_, state = await main(
selected_analysts,
max_debate_rounds,
max_risk_discuss_rounds,
online_tools,
company_name,
trade_date,
)
return await reflect_and_store(state, returns)
# {{/docs-fragment reflect_on_decisions}}
# {{docs-fragment execute_main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
# print(run.url)
# {{/docs-fragment execute_main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py)
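The `reflect_*` helpers imported above live in `reflection.py` in the example repo. Conceptually, each one asks an LLM to critique that agent's past output against the realized returns and writes the lesson back to the agent's memory. A minimal sketch follows; the prompt wording, model choice, and memory-write call are illustrative assumptions, not the repo's exact code.
```
from agents.utils.utils import AgentState, memory_init
from flyte_env import env
from langchain_openai import ChatOpenAI

@env.task
async def reflect_trader(state: AgentState, returns: str) -> None:
    memory = await memory_init(name="trader")
    prompt = f"""You are reviewing a past trading decision.
Decision and reasoning: {state.trader_investment_plan}
Realized returns: {returns}
Explain what the decision got right or wrong and distill one concise,
actionable lesson for similar future situations."""
    lesson = ChatOpenAI(model="gpt-4o-mini").invoke(prompt).content
    situation = (
        f"{state.market_report}\n\n{state.sentiment_report}\n\n"
        f"{state.news_report}\n\n{state.fundamentals_report}"
    )
    # Illustrative write API; the real helper's signature may differ
    memory.add_memory(situation, recommendation=lesson, key=state.trade_date)
```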
### Running the simulation
First, create Flyte secrets for your OpenAI API key (from [openai.com](https://platform.openai.com/api-keys)) and Finnhub API key (from [finnhub.io](https://finnhub.io/)):
```
flyte create secret openai_api_key
flyte create secret finnhub_api_key
```
Then [clone the repo](https://github.com/unionai/unionai-examples), navigate to the `v2/tutorials/trading_agents` directory, and run the following commands:
```
flyte create config --endpoint <endpoint> --project <project> --domain <domain> --builder remote
uv run main.py
```
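`main` exposes its simulation parameters as task inputs, so you can target a different ticker or date by passing overrides to `flyte.run` (a small sketch; the ticker and date values are arbitrary):
```
import flyte

if __name__ == "__main__":
    flyte.init_from_config()
    # Keyword arguments override main's defaults (NVDA on 2024-05-12)
    run = flyte.run(main, company_name="AAPL", trade_date="2024-06-03")
    print(run.url)
    run.wait()
```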
If you'd like to run the `reflect_on_decisions` task instead, comment out the `main` function call and uncomment the `reflect_on_decisions` call in the `__main__` block:
```
if __name__ == "__main__":
    flyte.init_from_config()
    # run = flyte.run(main)
    # print(run.url)
    # run.wait()
    run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py)
Then run:
```
uv run --prerelease=allow main.py
```
## Why Flyte? _(A quick note before you go)_
You might now be wondering: can't I just build all this with Python and LangChain?
Absolutely. But as your project grows, you'll likely run into these challenges:
1. **Observability**: Agent workflows can feel opaque. You send a prompt, get a response, but what happened in between?
- Were the right tools used?
- Were the correct arguments passed?
- How did the LLM reason through intermediate steps?
- Why did it fail?
Flyte gives you a window into each of these stages.
2. **Multi-agent coordination**: Real-world applications often require multiple agents with distinct roles and responsibilities. In such cases, you'll need:
- Isolated state per agent,
- Shared context where needed,
- And coordination, sequential or parallel.
Managing this manually gets fragile, fast. Flyte handles it for you.
3. **Scalability**: Agents and tools might need to run in isolated or containerized environments. Whether you're scaling out to more agents or more powerful hardware, Flyte lets you scale without taxing your local machine or racking up unnecessary cloud bills.
4. **Durability & recovery**: LLM-based workflows are often long-running and expensive. If something fails halfway:
- Do you lose all progress?
- Replay everything from scratch?
With Flyte, you get built-in caching, checkpointing, and recovery, so you can resume where you left off.
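As a small illustration of the observability point, tasks can be wrapped in named groups that show up as collapsible, inspectable units in the UI. This is a minimal sketch using the same `flyte.group` primitive that the retry loop in the next tutorial relies on; the environment and task names are placeholders.
```
import flyte

env = flyte.TaskEnvironment("observability_demo")

@env.task
async def debate_round(i: int) -> str:
    return f"round {i} argument"

@env.task
async def debate(rounds: int = 3) -> list[str]:
    arguments = []
    for i in range(rounds):
        # Each group appears as a named unit in the UI, with the inputs
        # and outputs of every task call inside it logged
        with flyte.group(f"debate-round-{i + 1}"):
            arguments.append(await debate_round(i))
    return arguments
```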
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/code-agent ===
# Run LLM-generated code
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/code_runner).
This example demonstrates how to run code generated by a large language model (LLM) using a `ContainerTask`.
The agent takes a user's question, generates Flyte 2 code using the Flyte 2 documentation as context, and runs it in an isolated container.
If the execution fails, the agent reflects on the error and retries
up to a configurable limit until it succeeds.
Running the generated code in a `ContainerTask` keeps it isolated in its own container, giving you the flexibility to execute arbitrary logic safely and reliably.
## What this example demonstrates
- How to combine LLM generation with programmatic execution.
- How to run untrusted or dynamically generated code securely.
- How to iteratively improve code using agent-like behavior.
## Setting up the agent environment
Let's start by importing the necessary libraries and setting up two environments: one for the container task and another for the agent task.
This example follows the `uv` script format to declare dependencies.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b23",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# ///
```
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File
code_runner_task = ContainerTask(
name="run_flyte_v2",
image=flyte.Image.from_debian_base(),
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"script": File},
outputs={"result": str, "exit_code": str},
command=[
"/bin/bash",
"-c",
(
"set -o pipefail && "
"uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
"echo $? > /var/outputs/exit_code"
),
],
resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
# {{docs-fragment env}}
import tempfile
from typing import Optional
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field
container_env = flyte.TaskEnvironment.from_task(
"code-runner-container", code_runner_task
)
env = flyte.TaskEnvironment(
name="code_runner",
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
resources=flyte.Resources(cpu=1),
depends_on=[container_env],
)
# {{/docs-fragment env}}
# {{docs-fragment code_base_model}}
class Code(BaseModel):
"""Schema for code solutions to questions about Flyte v2."""
prefix: str = Field(
default="", description="Description of the problem and approach"
)
imports: str = Field(
default="", description="Code block with just import statements"
)
code: str = Field(
default="", description="Code block not including import statements"
)
# {{/docs-fragment code_base_model}}
# {{docs-fragment agent_state}}
class AgentState(BaseModel):
messages: list[dict[str, str]] = Field(default_factory=list)
generation: Code = Field(default_factory=Code)
iterations: int = 0
error: str = "no"
output: Optional[str] = None
# {{/docs-fragment agent_state}}
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
# Grader prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.
Use the following pattern to execute the code:
if __name__ == "__main__":
flyte.init_from_config()
print(flyte.run(...))
Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
),
("placeholder", "{messages}"),
]
)
expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
from bs4 import BeautifulSoup
from langchain_community.document_loaders.recursive_url_loader import (
RecursiveUrlLoader,
)
loader = RecursiveUrlLoader(
url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
)
docs = loader.load()
# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
[doc.page_content for doc in d_reversed]
)
return concatenated_content
# {{/docs-fragment docs_retriever}}
# {{docs-fragment generate}}
@env.task
async def generate(
question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Generate a code solution
Args:
question (str): The user question
state (dict): The current graph state
concatenated_content (str): The concatenated docs content
debug (bool): Debug mode
Returns:
state (dict): New key added to state, generation
"""
print("---GENERATING CODE SOLUTION---")
messages = state.messages
iterations = state.iterations
error = state.error
# We have been routed back to generation with an error
if error == "yes":
messages += [
{
"role": "user",
"content": (
"Now, try again. Invoke the code tool to structure the output "
"with a prefix, imports, and code block:"
),
}
]
code_gen_chain = await generate_code_gen_chain(debug)
# Solution
code_solution = code_gen_chain.invoke(
{
"context": concatenated_content,
"messages": (
messages if messages else [{"role": "user", "content": question}]
),
}
)
messages += [
{
"role": "assistant",
"content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
}
]
return AgentState(
messages=messages,
generation=code_solution,
iterations=iterations + 1,
error=error,
output=state.output,
)
# {{/docs-fragment generate}}
# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
"""
Check code
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, error
"""
print("---CHECKING CODE---")
# State
messages = state.messages
code_solution = state.generation
iterations = state.iterations
# Get solution components
imports = code_solution.imports.strip()
code = code_solution.code.strip()
# Create temp file for imports
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False
) as imports_file:
imports_file.write(imports + "\n")
imports_path = imports_file.name
# Create temp file for code body
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
code_file.write(imports + "\n" + code + "\n")
code_path = code_file.name
# Check imports
import_output, import_exit_code = await code_runner_task(
script=await File.from_local(imports_path)
)
if import_exit_code.strip() != "0":
print("---CODE IMPORT CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the import test: {import_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=import_output,
)
else:
print("---CODE IMPORT CHECK: PASSED---")
# Check execution
code_output, code_exit_code = await code_runner_task(
script=await File.from_local(code_path)
)
if code_exit_code.strip() != "0":
print("---CODE BLOCK CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the code execution test: {code_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=code_output,
)
else:
print("---CODE BLOCK CHECK: PASSED---")
# No errors
print("---NO CODE TEST FAILURES---")
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="no",
output=code_output,
)
# {{/docs-fragment code_check}}
# {{docs-fragment reflect}}
@env.task
async def reflect(
state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Reflect on errors
Args:
state (dict): The current graph state
concatenated_content (str): Concatenated docs content
debug (bool): Debug mode
Returns:
state (dict): New key added to state, reflection
"""
print("---REFLECTING---")
# State
messages = state.messages
iterations = state.iterations
code_solution = state.generation
# Prompt reflection
code_gen_chain = await generate_code_gen_chain(debug)
# Add reflection
reflections = code_gen_chain.invoke(
{"context": concatenated_content, "messages": messages}
)
messages += [
{
"role": "assistant",
"content": f"Here are reflections on the error: {reflections}",
}
]
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error=state.error,
output=state.output,
)
# {{/docs-fragment reflect}}
# {{docs-fragment main}}
@env.task
async def main(
question: str = (
"Define a two-task pattern where the second catches OOM from the first and retries with more memory."
),
url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
max_iterations: int = 3,
debug: bool = False,
) -> str:
concatenated_content = await docs_retriever(url=url)
state: AgentState = AgentState()
iterations = 0
while True:
with flyte.group(f"code-generation-pass-{iterations + 1}"):
state = await generate(question, state, concatenated_content, debug)
state = await code_check(state)
error = state.error
iterations = state.iterations
if error == "no" or iterations >= max_iterations:
print("---DECISION: FINISH---")
code_solution = state.generation
prefix = code_solution.prefix
imports = code_solution.imports
code = code_solution.code
code_output = state.output
return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
else:
print("---DECISION: RE-TRY SOLUTION---")
state = await reflect(state, concatenated_content, debug)
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
> [!NOTE]
> You can set up access to the OpenAI API using a Flyte secret.
>
> ```
> flyte create secret openai_api_key
> ```
We store the LLM-generated code in a structured format. This allows us to:
- Enforce consistent formatting
- Make debugging easier
- Log and analyze generations systematically
By capturing metadata alongside the raw code, we maintain transparency and make it easier to iterate or trace issues over time.
```
# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )
# {{/docs-fragment code_base_model}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
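Because the chain is built with `with_structured_output(Code)`, invoking it returns a validated `Code` instance instead of free-form text, so each field can be consumed directly. A short usage sketch built on the names defined above:
```
async def demo(question: str, concatenated_content: str) -> None:
    chain = await generate_code_gen_chain(debug=False)
    solution: Code = chain.invoke(
        {
            "context": concatenated_content,
            "messages": [{"role": "user", "content": question}],
        }
    )
    print(solution.prefix)   # plain-language description of the approach
    print(solution.imports)  # import statements only
    print(solution.code)     # executable body without imports
```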
We then define a state model to persist the agent's history across iterations. This includes previous messages,
generated code, and any errors encountered.
Maintaining this history allows the agent to reflect on past attempts, avoid repeating mistakes,
and iteratively improve the generated code.
```
# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )
# {{/docs-fragment code_base_model}}


# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None
# {{/docs-fragment agent_state}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
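Because `AgentState` is a Pydantic model, it moves between tasks as plain serializable data, and each pass of the loop constructs a fresh instance rather than mutating shared state. Here is a small illustrative sketch (the message and output values are hypothetical, not produced by the agent):
```
state = AgentState()

# After a failed pass, the loop carries forward the accumulated messages
# and bumps the iteration counter in a brand-new AgentState instance.
state = AgentState(
    messages=state.messages + [{"role": "user", "content": "First attempt failed"}],
    generation=state.generation,
    iterations=state.iterations + 1,
    error="yes",
    output="Traceback (most recent call last): ...",
)
print(state.model_dump_json(indent=2))
```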
## Retrieve docs
We define a task to load documents from a given URL and concatenate them into a single string.
This string is then used as part of the LLM prompt.
We set `max_depth = 20` to avoid loading an excessive number of documents.
However, even with this limit, the resulting context can still be quite large.
To handle this, we use an LLM (GPT-4o in this case) that supports a large context window.
> [!NOTE]
> Appending all documents into a single string can result in extremely large contexts, potentially exceeding the LLM's token limit.
> If your dataset grows beyond what a single prompt can handle, there are a couple of strategies you can use.
> One option is to apply Retrieval-Augmented Generation (RAG), where you chunk the documents, embed them using a model,
> store the vectors in a vector database, and retrieve only the most relevant pieces at inference time.
>
> An alternative approach is to pass references to full files into the prompt, allowing the LLM to decide which files are most relevant based
> on natural-language search over file paths, summaries, or even contents. This method assumes that only a subset of files
> will be necessary for a given task, and the LLM is responsible for navigating the structure and identifying what to read.
> While this can be a lighter-weight solution for smaller datasets, its effectiveness depends on how well the LLM can
> reason over file references and the reliability of its internal search heuristics.
```
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))
    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content
# {{/docs-fragment docs_retriever}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
## Code generation
Next, we define a utility function to construct the LLM chain responsible for generating Python code from user input. The chain pipes a LangChain `ChatPromptTemplate` into an OpenAI chat model with structured output, so responses come back as well-formed `Code` objects rather than free text.
```
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))

Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)
    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
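To see what the chain produces, you can exercise it on its own. The following is a minimal driver sketch, not part of the tutorial; it assumes `OPENAI_API_KEY` is set and that the module above is importable:
```
import asyncio


async def demo() -> None:
    # debug=True selects the cheaper gpt-4o-mini model
    chain = await generate_code_gen_chain(debug=True)
    solution = chain.invoke(
        {
            "context": "...docs text...",  # placeholder for the retrieved documentation
            "messages": [{"role": "user", "content": "How do I define a task?"}],
        }
    )
    # with_structured_output(Code) returns a Code instance, not raw text
    print(solution.prefix)
    print(solution.imports)
    print(solution.code)


asyncio.run(demo())
```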
We then define a `generate` task responsible for producing the code solution.
To improve clarity and testability, the output is structured in three parts:
a short summary of the generated solution, a list of necessary imports,
and the main body of executable code.
```
# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (AgentState): The current agent state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (AgentState): Updated state containing the new generation
    """
    print("---GENERATING CODE SOLUTION---")
    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": (
                messages if messages else [{"role": "user", "content": question}]
            ),
        }
    )
    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]
    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )
# {{/docs-fragment generate}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
A `ContainerTask` then executes the generated code in an isolated container. It takes the script as input, runs it safely, and returns the program's output and exit code.
```
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
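Note how the shell command reports failures: the script's exit status is written to `/var/outputs/exit_code` instead of failing the container, so downstream tasks can inspect it. A rough local equivalent of that behavior, shown only for illustration:
```
import subprocess


def run_locally(path: str) -> tuple[str, str]:
    # Run the script with uv, merging stderr into the captured output
    # (mirroring the 2>&1 redirect), and report the exit code as a
    # string rather than raising on failure.
    proc = subprocess.run(
        ["uv", "run", "--script", path],
        capture_output=True,
        text=True,
    )
    return proc.stdout + proc.stderr, str(proc.returncode)
```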
The `code_check` task verifies that the generated code runs as expected. It tests the import statements first, then executes the full script, recording the output and any error messages in the agent state for further analysis.
```
# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (AgentState): The current agent state

    Returns:
        state (AgentState): Updated state with the error flag set
    """
    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )
    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )
    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )
# {{/docs-fragment code_check}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
If an error occurs, a separate `reflect` task analyzes the failure and adds its reflections to the agent state to guide future iterations. The `reflect` task is shown below, together with the remaining agent code.
```
# {{docs-fragment generate}}
@env.task
async def generate(
question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Generate a code solution
Args:
question (str): The user question
state (dict): The current graph state
concatenated_content (str): The concatenated docs content
debug (bool): Debug mode
Returns:
state (dict): New key added to state, generation
"""
print("---GENERATING CODE SOLUTION---")
messages = state.messages
iterations = state.iterations
error = state.error
# We have been routed back to generation with an error
if error == "yes":
messages += [
{
"role": "user",
"content": (
"Now, try again. Invoke the code tool to structure the output "
"with a prefix, imports, and code block:"
),
}
]
code_gen_chain = await generate_code_gen_chain(debug)
# Solution
code_solution = code_gen_chain.invoke(
{
"context": concatenated_content,
"messages": (
messages if messages else [{"role": "user", "content": question}]
),
}
)
messages += [
{
"role": "assistant",
"content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
}
]
return AgentState(
messages=messages,
generation=code_solution,
iterations=iterations + 1,
error=error,
output=state.output,
)
# {{/docs-fragment generate}}
# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
"""
Check code
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, error
"""
print("---CHECKING CODE---")
# State
messages = state.messages
code_solution = state.generation
iterations = state.iterations
# Get solution components
imports = code_solution.imports.strip()
code = code_solution.code.strip()
# Create temp file for imports
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False
) as imports_file:
imports_file.write(imports + "\n")
imports_path = imports_file.name
# Create temp file for code body
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
code_file.write(imports + "\n" + code + "\n")
code_path = code_file.name
# Check imports
import_output, import_exit_code = await code_runner_task(
script=await File.from_local(imports_path)
)
if import_exit_code.strip() != "0":
print("---CODE IMPORT CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the import test: {import_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=import_output,
)
else:
print("---CODE IMPORT CHECK: PASSED---")
# Check execution
code_output, code_exit_code = await code_runner_task(
script=await File.from_local(code_path)
)
if code_exit_code.strip() != "0":
print("---CODE BLOCK CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the code execution test: {code_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=code_output,
)
else:
print("---CODE BLOCK CHECK: PASSED---")
# No errors
print("---NO CODE TEST FAILURES---")
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="no",
output=code_output,
)
# {{/docs-fragment code_check}}
# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (dict): The current graph state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, reflection
    """
    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )
# {{/docs-fragment reflect}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
Finally, we define a `main` task that runs the code agent and orchestrates the steps above.
If the code execution fails, we reflect on the error and retry until we reach the maximum number of iterations.
```
# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)
    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation
                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code
                code_output = state.output
                return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
## Running the code agent
You can run the code agent on a Flyte/Union cluster using the following command:
```
uv run --prerelease=allow agent.py
```
If everything is working properly, you should see output similar to the following:
```
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE BLOCK CHECK: PASSED---
---NO CODE TEST FAILURES---
---DECISION: FINISH---
In this solution, we define two tasks using Flyte v2.
The first task, `oomer`, is designed to simulate an out-of-memory (OOM) error by attempting to allocate a large list.
The second task, `failure_recovery`, attempts to execute `oomer` and catches any OOM errors.
If an OOM error is caught, it retries the `oomer` task with increased memory resources.
This pattern demonstrates how to handle resource-related exceptions and dynamically adjust task configurations in Flyte workflows.
import asyncio
import flyte
import flyte.errors

env = flyte.TaskEnvironment(name="oom_example", resources=flyte.Resources(cpu=1, memory="250Mi"))

@env.task
async def oomer(x: int):
    large_list = [0] * 100000000  # Simulate OOM
    print(len(large_list))

@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42
...
```
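To run the agent with a different question, or in debug mode (which swaps in the smaller `gpt-4o-mini` model), you can pass the task's parameters directly to `flyte.run`. A minimal sketch, with an illustrative question of our own:
```
import flyte

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(
        main,  # the agent task defined above
        question="Write a task that fans out ten hello-world calls in parallel.",
        debug=True,
    )
    print(run.url)
    run.wait()
```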
=== PAGE: https://www.union.ai/docs/v2/flyte/tutorials/text_to_sql ===
# Text-to-SQL
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/text_to_sql); based on work by [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/workflow/advanced_text_to_sql/).
Data analytics drives modern decision-making, but SQL often creates a bottleneck. Writing queries requires technical expertise, so non-technical stakeholders must rely on data teams. That translation layer slows everyone down.
Text-to-SQL narrows this gap by turning natural language into executable SQL queries. It lowers the barrier to structured data and makes databases accessible to more people.
In this tutorial, we build a Text-to-SQL workflow using LlamaIndex and evaluate it on the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (a benchmark of natural language questions over semi-structured tables). We then explore prompt optimization to see whether it improves accuracy and show how to track prompts and results over time. Along the way, we'll see what worked, what didn't, and what we learned about building durable evaluation pipelines. The pattern here can be adapted to your own datasets and workflows.

## Ingesting data
We start by ingesting the WikiTableQuestions dataset, which comes as CSV files, into a SQLite database. This database serves as the source of truth for our Text-to-SQL pipeline.
```
import asyncio
import fnmatch
import os
import re
import zipfile
import flyte
import pandas as pd
import requests
from flyte.io import Dir, File
from llama_index.core.llms import ChatMessage
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from utils import env
# {{docs-fragment table_info}}
class TableInfo(BaseModel):
"""Information regarding a structured table."""
table_name: str = Field(..., description="table name (underscores only, no spaces)")
table_summary: str = Field(
..., description="short, concise summary/caption of the table"
)
# {{/docs-fragment table_info}}
@env.task
async def download_and_extract(zip_path: str, search_glob: str) -> Dir:
"""Download and extract the dataset zip file if not already available."""
output_zip = "data.zip"
extract_dir = "wiki_table_questions"
if not os.path.exists(zip_path):
response = requests.get(zip_path, stream=True)
with open(output_zip, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
else:
output_zip = zip_path
print(f"Using existing file {output_zip}")
os.makedirs(extract_dir, exist_ok=True)
with zipfile.ZipFile(output_zip, "r") as zip_ref:
for member in zip_ref.namelist():
if fnmatch.fnmatch(member, search_glob):
zip_ref.extract(member, extract_dir)
remote_dir = await Dir.from_local(extract_dir)
return remote_dir
async def read_csv_file(
csv_file: File, nrows: int | None = None
) -> pd.DataFrame | None:
"""Safely download and parse a CSV file into a DataFrame."""
try:
local_csv_file = await csv_file.download()
return pd.read_csv(local_csv_file, nrows=nrows)
except Exception as e:
print(f"Error parsing {csv_file.path}: {e}")
return None
def sanitize_column_name(col_name: str) -> str:
"""Sanitize column names by replacing spaces/special chars with underscores."""
return re.sub(r"\W+", "_", col_name)
async def create_table_from_dataframe(
df: pd.DataFrame, table_name: str, engine, metadata_obj
):
"""Create a SQL table from a Pandas DataFrame."""
# Sanitize column names
sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
df = df.rename(columns=sanitized_columns)
# Define table columns based on DataFrame dtypes
columns = [
Column(col, String if dtype == "object" else Integer)
for col, dtype in zip(df.columns, df.dtypes)
]
table = Table(table_name, metadata_obj, *columns)
# Create table in database
metadata_obj.create_all(engine)
# Insert data into table
with engine.begin() as conn:
for _, row in df.iterrows():
conn.execute(table.insert().values(**row.to_dict()))
@flyte.trace
async def create_table(
csv_file: File, table_info: TableInfo, database_path: str
) -> str:
"""Safely create a table from CSV if parsing succeeds."""
df = await read_csv_file(csv_file)
if df is None:
return "false"
print(f"Creating table: {table_info.table_name}")
engine = create_engine(f"sqlite:///{database_path}")
metadata_obj = MetaData()
await create_table_from_dataframe(df, table_info.table_name, engine, metadata_obj)
return "true"
@flyte.trace
async def llm_structured_predict(
df_str: str,
table_names: list[str],
prompt_tmpl: ChatPromptTemplate,
feedback: str,
llm: OpenAI,
) -> TableInfo:
return llm.structured_predict(
TableInfo,
prompt_tmpl,
feedback=feedback,
table_str=df_str,
exclude_table_name_list=str(list(table_names)),
)
async def generate_unique_table_info(
df_str: str,
table_names: list[str],
prompt_tmpl: ChatPromptTemplate,
llm: OpenAI,
tablename_lock: asyncio.Lock,
retries: int = 3,
) -> TableInfo | None:
"""Process a single CSV file to generate a unique TableInfo."""
last_table_name = None
for attempt in range(retries):
feedback = ""
if attempt > 0:
feedback = f"Note: '{last_table_name}' already exists. Please pick a new name not in {table_names}."
table_info = await llm_structured_predict(
df_str, table_names, prompt_tmpl, feedback, llm
)
last_table_name = table_info.table_name
async with tablename_lock:
if table_info.table_name not in table_names:
table_names.append(table_info.table_name)
return table_info
print(f"Table name {table_info.table_name} already exists, retrying...")
return None
async def process_csv_file(
csv_file: File,
table_names: list[str],
semaphore: asyncio.Semaphore,
tablename_lock: asyncio.Lock,
llm: OpenAI,
prompt_tmpl: ChatPromptTemplate,
) -> TableInfo | None:
"""Process a single CSV file to generate a unique TableInfo."""
async with semaphore:
df = await read_csv_file(csv_file, nrows=10)
if df is None:
return None
return await generate_unique_table_info(
df.to_csv(), table_names, prompt_tmpl, llm, tablename_lock
)
@env.task
async def extract_table_info(
data_dir: Dir, model: str, concurrency: int
) -> list[TableInfo | None]:
"""Extract structured table information from CSV files."""
table_names: list[str] = []
semaphore = asyncio.Semaphore(concurrency)
tablename_lock = asyncio.Lock()
llm = OpenAI(model=model)
prompt_str = """\
Provide a JSON object with the following fields:
- `table_name`: must be unique and descriptive (underscores only, no generic names).
- `table_summary`: short and concise summary of the table.
Do NOT use any of these table names: {exclude_table_name_list}
Table:
{table_str}
{feedback}
"""
prompt_tmpl = ChatPromptTemplate(
message_templates=[ChatMessage.from_str(prompt_str, role="user")]
)
tasks = [
process_csv_file(
csv_file, table_names, semaphore, tablename_lock, llm, prompt_tmpl
)
async for csv_file in data_dir.walk()
]
return await asyncio.gather(*tasks)
# {{docs-fragment data_ingestion}}
@env.task
async def data_ingestion(
csv_zip_path: str = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob: str = "WikiTableQuestions/csv/200-csv/*.csv",
concurrency: int = 5,
model: str = "gpt-4o-mini",
) -> tuple[File, list[TableInfo | None]]:
"""Main data ingestion pipeline: download β extract β analyze β create DB."""
data_dir = await download_and_extract(csv_zip_path, search_glob)
table_infos = await extract_table_info(data_dir, model, concurrency)
database_path = "wiki_table_questions.db"
i = 0
async for csv_file in data_dir.walk():
table_info = table_infos[i]
if table_info:
ok = await create_table(csv_file, table_info, database_path)
if ok == "false":
table_infos[i] = None
else:
print(f"Skipping table creation for {csv_file} due to missing TableInfo.")
i += 1
db_file = await File.from_local(database_path)
return db_file, table_infos
# {{/docs-fragment data_ingestion}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/data_ingestion.py)
The ingestion step:
1. Downloads the dataset (a zip archive from GitHub).
2. Extracts the CSV files locally.
3. Generates table metadata (names and descriptions).
4. Creates corresponding tables in SQLite.
The Flyte task returns both the path to the database and the generated table metadata.
```
# {{docs-fragment data_ingestion}}
@env.task
async def data_ingestion(
    csv_zip_path: str = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
    search_glob: str = "WikiTableQuestions/csv/200-csv/*.csv",
    concurrency: int = 5,
    model: str = "gpt-4o-mini",
) -> tuple[File, list[TableInfo | None]]:
    """Main data ingestion pipeline: download → extract → analyze → create DB."""
    data_dir = await download_and_extract(csv_zip_path, search_glob)
    table_infos = await extract_table_info(data_dir, model, concurrency)

    database_path = "wiki_table_questions.db"
    i = 0
    async for csv_file in data_dir.walk():
        table_info = table_infos[i]
        if table_info:
            ok = await create_table(csv_file, table_info, database_path)
            if ok == "false":
                table_infos[i] = None
        else:
            print(f"Skipping table creation for {csv_file} due to missing TableInfo.")
        i += 1

    db_file = await File.from_local(database_path)
    return db_file, table_infos
# {{/docs-fragment data_ingestion}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/data_ingestion.py)
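Downstream code can unpack the returned tuple directly. Here is a minimal sketch of a consuming task (the task name `inspect_ingestion` is ours, purely for illustration):
```
@env.task
async def inspect_ingestion() -> int:
    # Unpack the File handle and table metadata produced by data_ingestion.
    db_file, table_infos = await data_ingestion()
    local_db = await db_file.download()  # materialize the SQLite file locally
    print(f"Database downloaded to {local_db}")
    # Entries are None for CSVs that failed to parse or create.
    return sum(1 for t in table_infos if t is not None)
```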
## From question to SQL
Next, we define a workflow that converts natural language into executable SQL using a retrieval-augmented generation (RAG) approach.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "sqlalchemy>=2.0.0",
# "pandas>=2.0.0",
# "requests>=2.25.0",
# "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///
import asyncio
from pathlib import Path
import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
PromptTemplate,
SQLDatabase,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env
# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
"""Index a single table into vector store."""
path = f"{table_index_dir}/{table_name}"
engine = create_engine(database_uri)
def _fetch_rows():
with engine.connect() as conn:
cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
return cursor.fetchall()
result = await asyncio.to_thread(_fetch_rows)
nodes = [TextNode(text=str(tuple(row))) for row in result]
index = VectorStoreIndex(nodes)
index.set_index_id("vector_index")
index.storage_context.persist(path)
return path
@env.task
async def index_all_tables(db_file: File) -> Dir:
"""Index all tables concurrently."""
table_index_dir = "table_indices"
Path(table_index_dir).mkdir(exist_ok=True)
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
tasks = [
index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
for t in sql_database.get_usable_table_names()
]
await asyncio.gather(*tasks)
remote_dir = await Dir.from_local(table_index_dir)
return remote_dir
# {{/docs-fragment index_tables}}
@flyte.trace
async def get_table_schema_context(
table_schema_obj: SQLTableSchema,
database_uri: str,
) -> str:
"""Retrieve schema + optional description context for a single table."""
engine = create_engine(database_uri)
sql_database = SQLDatabase(engine)
table_info = sql_database.get_single_table_info(table_schema_obj.table_name)
if table_schema_obj.context_str:
table_info += f" The table description is: {table_schema_obj.context_str}"
return table_info
@flyte.trace
async def get_table_row_context(
table_schema_obj: SQLTableSchema,
local_vector_index_dir: str,
query: str,
) -> str:
"""Retrieve row-level context examples using vector search."""
storage_context = StorageContext.from_defaults(
persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
)
vector_index = load_index_from_storage(storage_context, index_id="vector_index")
vector_retriever = vector_index.as_retriever(similarity_top_k=2)
relevant_nodes = vector_retriever.retrieve(query)
if not relevant_nodes:
return ""
row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
for node in relevant_nodes:
row_context += str(node.get_content()) + "\n"
return row_context
async def process_table(
table_schema_obj: SQLTableSchema,
database_uri: str,
local_vector_index_dir: str,
query: str,
) -> str:
"""Combine schema + row context for one table."""
table_info = await get_table_schema_context(table_schema_obj, database_uri)
row_context = await get_table_row_context(
table_schema_obj, local_vector_index_dir, query
)
full_context = table_info
if row_context:
full_context += "\n" + row_context
print(f"Table Info: {full_context}")
return full_context
async def get_table_context_and_rows_str(
query: str,
database_uri: str,
table_schema_objs: list[SQLTableSchema],
vector_index_dir: Dir,
):
"""Get combined schema + row context for all tables."""
local_vector_index_dir = await vector_index_dir.download()
# run per-table work concurrently
context_strs = await asyncio.gather(
*[
process_table(t, database_uri, local_vector_index_dir, query)
for t in table_schema_objs
]
)
return "\n\n".join(context_strs)
# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
query: str,
table_infos: list[TableInfo | None],
db_file: File,
vector_index_dir: Dir,
) -> str:
"""Retrieve relevant tables and return schema context string."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
for t in table_infos
if t is not None
]
obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)
retrieved_schemas = obj_retriever.retrieve(query)
return await get_table_context_and_rows_str(
query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
)
# {{/docs-fragment retrieve_tables}}
def parse_response_to_sql(chat_response: ChatResponse) -> str:
"""Extract SQL query from LLM response."""
response = chat_response.message.content
sql_query_start = response.find("SQLQuery:")
if sql_query_start != -1:
response = response[sql_query_start:]
if response.startswith("SQLQuery:"):
response = response[len("SQLQuery:") :]
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
return response.strip().strip("```").strip()
# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
"""Generate SQL query from natural language question and table context."""
llm = OpenAI(model=model)
fmt_messages = (
PromptTemplate(
prompt,
prompt_type=PromptType.TEXT_TO_SQL,
)
.partial_format(dialect="sqlite")
.format_messages(query_str=query, schema=table_context)
)
chat_response = await llm.achat(fmt_messages)
return parse_response_to_sql(chat_response)
@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
"""Run SQL query on database and synthesize final response."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
response_synthesis_prompt = PromptTemplate(
"Given an input question, synthesize a response from the query results.\n"
"Query: {query_str}\n"
"SQL: {sql_query}\n"
"SQL Response: {context_str}\n"
"Response: "
)
llm = OpenAI(model=model)
fmt_messages = response_synthesis_prompt.format_messages(
sql_query=sql,
context_str=str(retrieved_rows),
query_str=query,
)
chat_response = await llm.achat(fmt_messages)
return chat_response.message.content
# {{/docs-fragment sql_and_response}}
# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
system_prompt: str = (
"Given an input question, first create a syntactically correct {dialect} "
"query to run, then look at the results of the query and return the answer. "
"You can order the results by a relevant column to return the most "
"interesting examples in the database.\n\n"
"Never query for all the columns from a specific table, only ask for a "
"few relevant columns given the question.\n\n"
"Pay attention to use only the column names that you can see in the schema "
"description. "
"Be careful to not query for columns that do not exist. "
"Pay attention to which column is in which table. "
"Also, qualify column names with the table name when needed. "
"You are required to use the following format, each taking one line:\n\n"
"Question: Question here\n"
"SQLQuery: SQL Query to run\n"
"SQLResult: Result of the SQLQuery\n"
"Answer: Final answer here\n\n"
"Only use tables listed below.\n"
"{schema}\n\n"
"Question: {query_str}\n"
"SQLQuery: "
),
query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
model: str = "gpt-4o-mini",
) -> str:
db_file, table_infos = await data_ingestion()
vector_index_dir = await index_all_tables(db_file)
table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
sql = await generate_sql(query, table_context, model, system_prompt)
return await generate_response(query, sql, db_file, model)
# {{/docs-fragment text_to_sql}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(text_to_sql)
print(run.url)
run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
The main `text_to_sql` task orchestrates the pipeline:
- Ingest data
- Build vector indices for each table
- Retrieve relevant tables and rows
- Generate SQL queries with an LLM
- Execute queries and synthesize answers
We use OpenAI GPT models with carefully structured prompts to maximize SQL correctness.
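The prompt's rigid `SQLQuery:`/`SQLResult:` line format is what makes the model's output machine-parseable: `parse_response_to_sql` keeps only the text between those two markers. A minimal sketch with a made-up model response:
```
from llama_index.core.llms import ChatMessage, ChatResponse

# Hypothetical model output that follows the required line format.
fake_response = ChatResponse(
    message=ChatMessage(
        role="assistant",
        content=(
            "Question: How many albums are listed?\n"
            "SQLQuery: SELECT COUNT(*) FROM albums\n"
            "SQLResult: [(12,)]\n"
            "Answer: 12"
        ),
    )
)

# Prints: SELECT COUNT(*) FROM albums
print(parse_response_to_sql(fake_response))
```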
### Vector indexing
We index each table's rows semantically so the model can retrieve relevant examples during SQL generation.
```
# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
    """Index a single table into vector store."""
    path = f"{table_index_dir}/{table_name}"
    engine = create_engine(database_uri)

    def _fetch_rows():
        with engine.connect() as conn:
            cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
            return cursor.fetchall()

    result = await asyncio.to_thread(_fetch_rows)
    nodes = [TextNode(text=str(tuple(row))) for row in result]
    index = VectorStoreIndex(nodes)
    index.set_index_id("vector_index")
    index.storage_context.persist(path)
    return path

@env.task
async def index_all_tables(db_file: File) -> Dir:
    """Index all tables concurrently."""
    table_index_dir = "table_indices"
    Path(table_index_dir).mkdir(exist_ok=True)
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    tasks = [
        index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
        for t in sql_database.get_usable_table_names()
    ]
    await asyncio.gather(*tasks)
    remote_dir = await Dir.from_local(table_index_dir)
    return remote_dir
# {{/docs-fragment index_tables}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
Each row becomes a text node stored in LlamaIndex's `VectorStoreIndex`. This lets the system pull semantically similar rows when handling queries.
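You can query one of these persisted indices by hand, mirroring what `get_table_row_context` does in the full listing above. A minimal sketch (the directory name under `table_indices/` is illustrative):
```
from llama_index.core import StorageContext, load_index_from_storage

# Load one table's persisted index from disk.
storage_context = StorageContext.from_defaults(persist_dir="table_indices/bad_boy_albums")
vector_index = load_index_from_storage(storage_context, index_id="vector_index")

# Fetch the two stored rows most similar to the question.
retriever = vector_index.as_retriever(similarity_top_k=2)
for node in retriever.retrieve("When was The Notorious BIG signed to Bad Boy?"):
    print(node.get_content())
```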
### Table retrieval and context building
We then retrieve the most relevant tables for a given query and build rich context that combines schema information with sample rows.
```
# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
query: str,
table_infos: list[TableInfo | None],
db_file: File,
vector_index_dir: Dir,
) -> str:
"""Retrieve relevant tables and return schema context string."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
for t in table_infos
if t is not None
]
obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)
retrieved_schemas = obj_retriever.retrieve(query)
return await get_table_context_and_rows_str(
query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
)
# {{/docs-fragment retrieve_tables}}
def parse_response_to_sql(chat_response: ChatResponse) -> str:
"""Extract SQL query from LLM response."""
response = chat_response.message.content
sql_query_start = response.find("SQLQuery:")
if sql_query_start != -1:
response = response[sql_query_start:]
if response.startswith("SQLQuery:"):
response = response[len("SQLQuery:") :]
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
return response.strip().strip("```").strip()
# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
"""Generate SQL query from natural language question and table context."""
llm = OpenAI(model=model)
fmt_messages = (
PromptTemplate(
prompt,
prompt_type=PromptType.TEXT_TO_SQL,
)
.partial_format(dialect="sqlite")
.format_messages(query_str=query, schema=table_context)
)
chat_response = await llm.achat(fmt_messages)
return parse_response_to_sql(chat_response)
@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
"""Run SQL query on database and synthesize final response."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
response_synthesis_prompt = PromptTemplate(
"Given an input question, synthesize a response from the query results.\n"
"Query: {query_str}\n"
"SQL: {sql_query}\n"
"SQL Response: {context_str}\n"
"Response: "
)
llm = OpenAI(model=model)
fmt_messages = response_synthesis_prompt.format_messages(
sql_query=sql,
context_str=str(retrieved_rows),
query_str=query,
)
chat_response = await llm.achat(fmt_messages)
return chat_response.message.content
# {{/docs-fragment sql_and_response}}
# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
system_prompt: str = (
"Given an input question, first create a syntactically correct {dialect} "
"query to run, then look at the results of the query and return the answer. "
"You can order the results by a relevant column to return the most "
"interesting examples in the database.\n\n"
"Never query for all the columns from a specific table, only ask for a "
"few relevant columns given the question.\n\n"
"Pay attention to use only the column names that you can see in the schema "
"description. "
"Be careful to not query for columns that do not exist. "
"Pay attention to which column is in which table. "
"Also, qualify column names with the table name when needed. "
"You are required to use the following format, each taking one line:\n\n"
"Question: Question here\n"
"SQLQuery: SQL Query to run\n"
"SQLResult: Result of the SQLQuery\n"
"Answer: Final answer here\n\n"
"Only use tables listed below.\n"
"{schema}\n\n"
"Question: {query_str}\n"
"SQLQuery: "
),
query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
model: str = "gpt-4o-mini",
) -> str:
db_file, table_infos = await data_ingestion()
vector_index_dir = await index_all_tables(db_file)
table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
sql = await generate_sql(query, table_context, model, system_prompt)
return await generate_response(query, sql, db_file, model)
# {{/docs-fragment text_to_sql}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(text_to_sql)
print(run.url)
run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
The retriever selects tables via semantic similarity, then attaches their schema and example rows. This context grounds the model's SQL generation in the database's actual structure and content.
### SQL generation and response synthesis
Finally, we generate SQL queries and produce natural language answers.
```
def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Extract SQL query from LLM response."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()


# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)
    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )
    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)


@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)
    retrieved_rows = sql_retriever.retrieve(sql)
    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )
    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content
# {{/docs-fragment sql_and_response}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
The SQL generation prompt includes schema, example rows, and formatting rules. After execution, the system returns a final answer.
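For example, with a valid Flyte config and an `OPENAI_API_KEY` available to the task environment, you can kick off the pipeline with your own question. The keyword arguments match the `text_to_sql` signature above; the question itself is illustrative:
```
import flyte
from text_to_sql import text_to_sql

flyte.init_from_config()
run = flyte.run(
    text_to_sql,
    query="Which artist released the most albums?",  # illustrative question
    model="gpt-4o-mini",
)
print(run.url)
run.wait()
```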
At this point, we have an end-to-end Text-to-SQL pipeline: natural language questions go in, SQL queries run, and answers come back. To make this workflow production-ready, we leveraged several Flyte 2 capabilities. Caching ensures that repeated steps, like table ingestion or vector indexing, don't rerun unnecessarily, saving time and compute. Containerization provides consistent, reproducible execution across environments, making it easier to scale and deploy. Observability features let us track every step of the pipeline, monitor performance, and debug issues quickly.
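These capabilities hang off the shared `TaskEnvironment` (the `env` imported from `utils` in the listings above). As a rough sketch of what such an environment could look like, assuming the Flyte 2 beta `Image`, `Resources`, and cache APIs (the exact values here are illustrative, not the tutorial's actual configuration):
```
import flyte

# Illustrative environment definition: a container image with the pipeline's
# dependencies, modest resources, and automatic caching so unchanged steps
# (ingestion, indexing) are reused across runs.
env = flyte.TaskEnvironment(
    name="text_to_sql_env",
    image=flyte.Image.from_debian_base().with_pip_packages(
        "llama-index-core", "llama-index-llms-openai", "sqlalchemy", "pandas"
    ),
    resources=flyte.Resources(cpu=1, memory="2Gi"),
    cache="auto",  # reuse cached outputs when inputs and code are unchanged
)
```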
The pipeline now works end-to-end, but to understand how well it performs across many questions, and to improve that performance over time, we can start experimenting with prompt tuning.
Two things help make this process meaningful:
- **A clean evaluation dataset** - so we can measure accuracy against trusted ground truth.
- **A systematic evaluation loop** - so we can see whether prompt changes or other adjustments actually help.
With these in place, the next step is to build a "golden" QA dataset that will guide iterative prompt optimization.
## Building the QA dataset
> [!NOTE]
> The WikiTableQuestions dataset already includes question-answer pairs, available in its [GitHub repository](https://github.com/ppasupat/WikiTableQuestions/tree/master/data). To use them for this workflow, you'll need to adapt the data into the required format, but the raw material is there for you to build on.
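As a minimal sketch of such an adaptation, assuming the dataset's documented TSV layout (`utterance` holds the question, `targetValue` the answer; verify the column names against the files you download):
```
import csv

import pandas as pd

# Hypothetical adapter: convert a WikiTableQuestions TSV split into the
# 'input'/'target' CSV layout the evaluation tasks below expect.
def wtq_to_eval_csv(tsv_path: str, csv_path: str) -> None:
    # WTQ TSVs are unquoted and backslash-escaped, so disable quote handling.
    df = pd.read_csv(tsv_path, sep="\t", quoting=csv.QUOTE_NONE)
    out = pd.DataFrame({"input": df["utterance"], "target": df["targetValue"]})
    out.to_csv(csv_path, index=False)

wtq_to_eval_csv("WikiTableQuestions/data/training.tsv", "qa_pairs.csv")
```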
We generate a dataset of natural language questions paired with executable SQL queries. This dataset acts as the benchmark for prompt tuning and evaluation.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pandas>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "pydantic>=2.0.0",
# ]
# main = "build_eval_dataset"
# params = ""
# ///
import sqlite3
import flyte
import pandas as pd
from data_ingestion import data_ingestion
from flyte.io import File
from llama_index.core import PromptTemplate
from llama_index.llms.openai import OpenAI
from utils import env
from pydantic import BaseModel
class QAItem(BaseModel):
question: str
sql: str
class QAList(BaseModel):
items: list[QAItem]
# {{docs-fragment get_and_split_schema}}
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
"""
Download the SQLite DB, extract schema info (columns + sample rows),
then split it into chunks with up to `tables_per_chunk` tables each.
"""
await db_file.download(local_path="local_db.sqlite")
conn = sqlite3.connect("local_db.sqlite")
cursor = conn.cursor()
tables = cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table';"
).fetchall()
schema_blocks = []
for table in tables:
table_name = table[0]
# columns
cursor.execute(f"PRAGMA table_info({table_name});")
columns = [col[1] for col in cursor.fetchall()]
block = f"Table: {table_name}({', '.join(columns)})"
# sample rows
cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
rows = cursor.fetchall()
if rows:
block += "\nSample rows:\n"
for row in rows:
block += f"{row}\n"
schema_blocks.append(block)
conn.close()
chunks = []
current_chunk = []
for block in schema_blocks:
current_chunk.append(block)
if len(current_chunk) >= tables_per_chunk:
chunks.append("\n".join(current_chunk))
current_chunk = []
if current_chunk:
chunks.append("\n".join(current_chunk))
return chunks
# {{/docs-fragment get_and_split_schema}}
# {{docs-fragment generate_questions_and_sql}}
@flyte.trace
async def generate_questions_and_sql(
schema: str, num_samples: int, batch_size: int
) -> QAList:
llm = OpenAI(model="gpt-4.1")
prompt_tmpl = PromptTemplate(
"""Prompt: You are helping build a Text-to-SQL dataset.
Here is the database schema:
{schema}
Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.
Reasoning process (you must follow this internally):
- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.
Final Output:
Return only a JSON object with one field:
- "items": a list of {num} objects, each with:
- "question": the natural language question
- "sql": the corresponding SQL query
"""
)
all_items: list[QAItem] = []
# batch generation
for start in range(0, num_samples, batch_size):
current_num = min(batch_size, num_samples - start)
        response = llm.structured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",  # fill the template's {dialect} placeholder
        )
all_items.extend(response.items)
# deduplicate
seen = set()
unique_items: list[QAItem] = []
for item in all_items:
key = (item.question.strip().lower(), item.sql.strip().lower())
if key not in seen:
seen.add(key)
unique_items.append(item)
return QAList(items=unique_items[:num_samples])
# {{/docs-fragment generate_questions_and_sql}}
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
"""Validate a batch of question/sql/result dicts using one LLM call."""
batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""
for i, pair in enumerate(pairs, start=1):
batch_prompt += f"""
Example {i}:
Question:
{pair['question']}
SQL:
{pair['sql']}
Result:
{pair['rows']}
---
"""
llm = OpenAI(model="gpt-4.1")
resp = await llm.acomplete(batch_prompt)
# Expect exactly one True/False per example
results = [
line.strip()
for line in resp.text.splitlines()
if line.strip() in ("True", "False")
]
return results
# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
await db_file.download(local_path="local_db.sqlite")
conn = sqlite3.connect("local_db.sqlite")
cursor = conn.cursor()
qa_data = []
batch = []
for pair in question_sql_pairs.items:
q, sql = pair.question, pair.sql
try:
cursor.execute(sql)
rows = cursor.fetchall()
batch.append({"question": q, "sql": sql, "rows": str(rows)})
# process when batch is full
if len(batch) == batch_size:
results = await llm_validate_batch(batch)
for pair, is_valid in zip(batch, results):
if is_valid == "True":
qa_data.append(
{
"input": pair["question"],
"sql": pair["sql"],
"target": pair["rows"],
}
)
else:
print(f"Filtered out incorrect result for: {pair['question']}")
batch = []
except Exception as e:
print(f"Skipping invalid SQL: {sql} ({e})")
# process leftover batch
if batch:
results = await llm_validate_batch(batch)
for pair, is_valid in zip(batch, results):
if is_valid == "True":
qa_data.append(
{
"input": pair["question"],
"sql": pair["sql"],
"target": pair["rows"],
}
)
else:
print(f"Filtered out incorrect result for: {pair['question']}")
conn.close()
return qa_data
# {{/docs-fragment validate_sql}}
@flyte.trace
async def save_to_csv(qa_data: list[dict]) -> File:
df = pd.DataFrame(qa_data, columns=["input", "target", "sql"])
csv_file = "qa_dataset.csv"
df.to_csv(csv_file, index=False)
return await File.from_local(csv_file)
# {{docs-fragment build_eval_dataset}}
@env.task
async def build_eval_dataset(
num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> File:
db_file, _ = await data_ingestion()
schema_chunks = await get_and_split_schema(db_file, tables_per_chunk)
per_chunk_samples = max(1, num_samples // len(schema_chunks))
final_qa_data = []
for chunk in schema_chunks:
qa_list = await generate_questions_and_sql(
schema=chunk,
num_samples=per_chunk_samples,
batch_size=batch_size,
)
qa_data = await validate_sql(db_file, qa_list, batch_size)
final_qa_data.extend(qa_data)
csv_file = await save_to_csv(final_qa_data)
return csv_file
# {{/docs-fragment build_eval_dataset}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(build_eval_dataset)
print(run.url)
run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
The pipeline does the following:
- **Schema extraction** - pull full database schemas, including table names, columns, and sample rows.
- **Question-SQL generation** - use an LLM to produce natural language questions with matching SQL queries.
- **Validation** - run each query against the database, drop queries that fail, and filter out results that don't actually answer the question.
- **Final export** - store the clean, validated pairs in CSV format for downstream use.
### Schema extraction and chunking
We break schemas into smaller chunks to cover all tables evenly. This avoids overfitting to a subset of tables and ensures broad coverage across the dataset.
```
# {{docs-fragment get_and_split_schema}}
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
    """
    Download the SQLite DB, extract schema info (columns + sample rows),
    then split it into chunks with up to `tables_per_chunk` tables each.
    """
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()
    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()
    schema_blocks = []
    for table in tables:
        table_name = table[0]
        # columns
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = [col[1] for col in cursor.fetchall()]
        block = f"Table: {table_name}({', '.join(columns)})"
        # sample rows
        cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
        rows = cursor.fetchall()
        if rows:
            block += "\nSample rows:\n"
            for row in rows:
                block += f"{row}\n"
        schema_blocks.append(block)
    conn.close()
    chunks = []
    current_chunk = []
    for block in schema_blocks:
        current_chunk.append(block)
        if len(current_chunk) >= tables_per_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append("\n".join(current_chunk))
    return chunks
# {{/docs-fragment get_and_split_schema}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
### Question and SQL generation
Using structured prompts, we ask an LLM to generate realistic questions users might ask, then pair them with syntactically valid SQL queries. Deduplication ensures diversity across queries.
```
class QAItem(BaseModel):
    question: str
    sql: str


class QAList(BaseModel):
    items: list[QAItem]


# {{docs-fragment generate_questions_and_sql}}
@flyte.trace
async def generate_questions_and_sql(
    schema: str, num_samples: int, batch_size: int
) -> QAList:
    llm = OpenAI(model="gpt-4.1")
    prompt_tmpl = PromptTemplate(
        """Prompt: You are helping build a Text-to-SQL dataset.
Here is the database schema:
{schema}
Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.
Reasoning process (you must follow this internally):
- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.
Final Output:
Return only a JSON object with one field:
- "items": a list of {num} objects, each with:
  - "question": the natural language question
  - "sql": the corresponding SQL query
"""
    )
    all_items: list[QAItem] = []
    # batch generation
    for start in range(0, num_samples, batch_size):
        current_num = min(batch_size, num_samples - start)
        response = llm.structured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",  # fill the template's {dialect} placeholder
        )
        all_items.extend(response.items)
    # deduplicate
    seen = set()
    unique_items: list[QAItem] = []
    for item in all_items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique_items.append(item)
    return QAList(items=unique_items[:num_samples])
# {{/docs-fragment generate_questions_and_sql}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
### Validation and quality control
Each generated SQL query runs against the database, and another LLM double-checks that the result matches the intent of the natural language question.
```
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
    """Validate a batch of question/sql/result dicts using one LLM call."""
    batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""
    for i, pair in enumerate(pairs, start=1):
        batch_prompt += f"""
Example {i}:
Question:
{pair['question']}
SQL:
{pair['sql']}
Result:
{pair['rows']}
---
"""
    llm = OpenAI(model="gpt-4.1")
    resp = await llm.acomplete(batch_prompt)
    # Expect exactly one True/False per example
    results = [
        line.strip()
        for line in resp.text.splitlines()
        if line.strip() in ("True", "False")
    ]
    return results


# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
    db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()
    qa_data = []
    batch = []
    for pair in question_sql_pairs.items:
        q, sql = pair.question, pair.sql
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
            batch.append({"question": q, "sql": sql, "rows": str(rows)})
            # process when batch is full
            if len(batch) == batch_size:
                results = await llm_validate_batch(batch)
                for pair, is_valid in zip(batch, results):
                    if is_valid == "True":
                        qa_data.append(
                            {
                                "input": pair["question"],
                                "sql": pair["sql"],
                                "target": pair["rows"],
                            }
                        )
                    else:
                        print(f"Filtered out incorrect result for: {pair['question']}")
                batch = []
        except Exception as e:
            print(f"Skipping invalid SQL: {sql} ({e})")
    # process leftover batch
    if batch:
        results = await llm_validate_batch(batch)
        for pair, is_valid in zip(batch, results):
            if is_valid == "True":
                qa_data.append(
                    {
                        "input": pair["question"],
                        "sql": pair["sql"],
                        "target": pair["rows"],
                    }
                )
            else:
                print(f"Filtered out incorrect result for: {pair['question']}")
    conn.close()
    return qa_data
# {{/docs-fragment validate_sql}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
Even with automated checks, human review remains critical. Since this dataset serves as the ground truth, mislabeled pairs can distort evaluation. For production use, always invest in human-in-the-loop review.
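A review pass doesn't need heavy tooling to be useful. As a sketch, you could step through the generated CSV in a terminal and keep only the rows a reviewer approves (a real workflow would use a proper labeling tool):
```
import pandas as pd

# Minimal terminal-based review pass over the generated dataset: show each
# question/SQL/target triple and keep only explicitly approved rows.
df = pd.read_csv("qa_dataset.csv")
kept = []
for _, row in df.iterrows():
    print(f"\nQ: {row['input']}\nSQL: {row['sql']}\nTarget: {row['target']}")
    if input("Keep? [y/N] ").strip().lower() == "y":
        kept.append(row)
pd.DataFrame(kept).to_csv("qa_dataset_reviewed.csv", index=False)
```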
## Optimizing prompts
With the QA dataset in place, we can turn to prompt optimization. The idea: start from a baseline prompt, generate new variants, and measure whether accuracy improves.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pandas>=2.0.0",
# "sqlalchemy>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env
CSS = """
"""
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into val/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Val/Test split
df_renamed = df.rename(columns={"input": "question", "target": "answer"})
n = len(df_renamed)
split = n // 2
df_val = df_renamed.iloc[:split]
df_test = df_renamed.iloc[split:]
return df_val, df_test
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
if retrieved_rows:
# Get the structured result and stringify
return str(retrieved_rows[0].node.metadata["result"])
return ""
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
db_file: File,
table_infos: list[TableInfo | None],
vector_index_dir: Dir,
) -> dict:
# Generate response from target model
table_context = await retrieve_tables(
question, table_infos, db_file, vector_index_dir
)
sql = await generate_sql(
question,
table_context,
target_model_config.model_name,
target_model_config.prompt,
)
sql = sql.replace("sql\n", "")
try:
response = await generate_response(db_file, sql)
except Exception as e:
print(f"Failed to generate response for question {question}: {e}")
response = None
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
query_str=question,
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"sql": sql,
"is_correct": verdict_clean == "true",
}
async def run_grouped_task(
i,
index,
question,
answer,
sql,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
db_file,
table_infos,
vector_index_dir,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
db_file,
table_infos,
vector_index_dir,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "β Yes"
else:
correct_html = "β No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
                # Update chart (the HTML chart markup, built from accuracy_pct,
                # is elided in this excerpt)
                await flyte.report.log.aio(
                    f"",
                    do_flush=True,
                )
                # Add row to table (the HTML row markup, built from correct_html,
                # is elided in this excerpt)
                await flyte.report.log.aio(
                    f"""
                    """,
                    do_flush=True,
                )
# ... (the rest of run_grouped_task, the DatabaseConfig dataclass, and the
# evaluate_prompt and prompt_optimizer tasks are elided in this excerpt; see
# the linked source file. prompt_optimizer ends by returning the best prompt:)
    return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
    # (the HTML row markup, styled with `color`, is elided in this excerpt)
    await flyte.report.log.aio(
        f"""
        {html.escape(prompt)}
        {pct:.1f}%
        """,
        do_flush=True,
    )
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
ground_truth_csv: File | str = "/root/ground_truth.csv",
db_config: DatabaseConfig = DatabaseConfig(
csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob="WikiTableQuestions/csv/200-csv/*.csv",
concurrency=5,
model="gpt-4o-mini",
),
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema. For example:

Schema:
artists(id, name)
albums(id, title, artist_id, release_year)

Question: How many albums did The Beatles release?

SQL: SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 5,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
ground_truth_csv = await File.from_local(ground_truth_csv)
df_val, df_test = await data_prep(ground_truth_csv)
best_prompt, val_accuracy = await prompt_optimizer(
df_val,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
db_config,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
return {
"best_prompt": best_prompt,
"validation_accuracy": val_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py)
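The listing above elides the optimizer loop itself, but its shape follows from the optimizer prompt: score each candidate prompt on the validation set, show the model all prior prompts with their accuracies in ascending order, and ask for a new variant wrapped in double square brackets. A minimal sketch of that loop, with `evaluate` and `propose` as hypothetical stand-ins for the elided helpers (see the source file for the real implementation):
```
import re
from dataclasses import dataclass


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


async def prompt_optimizer_sketch(
    baseline_prompt: str,
    evaluate,  # async callable: (prompt: str) -> float validation accuracy
    propose,  # async callable: (history: str) -> str raw optimizer-model output
    max_iterations: int,
) -> PromptResult:
    history = [PromptResult(baseline_prompt, await evaluate(baseline_prompt))]
    for _ in range(max_iterations):
        # Present prior prompts in ascending accuracy order, as the optimizer
        # prompt above expects.
        history.sort(key=lambda r: r.accuracy)
        scores = "\n\n".join(
            f"Prompt:\n{r.prompt}\nAccuracy: {r.accuracy:.2f}" for r in history
        )
        raw = await propose(scores)
        # The optimizer model is instructed to wrap its new prompt in [[...]].
        match = re.search(r"\[\[(.*?)\]\]", raw, re.DOTALL)
        if not match:
            continue
        candidate = match.group(1).strip()
        history.append(PromptResult(candidate, await evaluate(candidate)))
    return max(history, key=lambda r: r.accuracy)
```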
### Evaluation pipeline
We evaluate each prompt variant against the golden dataset, split into validation and test sets, and record accuracy metrics in real time.
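The `evaluate_prompt` helper is likewise elided from the listing; conceptually, it fans the per-row generate-and-review step out over the dataset under a concurrency cap and returns the fraction of rows judged correct. A rough sketch under those assumptions (`review_row` is a hypothetical stand-in for `generate_and_review`):
```
import asyncio

import pandas as pd


async def evaluate_prompt_sketch(
    df: pd.DataFrame, review_row, concurrency: int
) -> float:
    # Bound the number of in-flight rows, mirroring run_grouped_task's semaphore.
    semaphore = asyncio.Semaphore(concurrency)

    async def bounded(question: str, answer: str) -> dict:
        async with semaphore:
            return await review_row(question, answer)

    results = await asyncio.gather(
        *(bounded(row["question"], row["answer"]) for _, row in df.iterrows())
    )
    # Accuracy is the fraction of rows the review model judged correct.
    return sum(1 for r in results if r["is_correct"]) / max(len(results), 1)
```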
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pandas>=2.0.0",
# "sqlalchemy>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env
CSS = """
"""
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into val/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Val/Test split
df_renamed = df.rename(columns={"input": "question", "target": "answer"})
n = len(df_renamed)
split = n // 2
df_val = df_renamed.iloc[:split]
df_test = df_renamed.iloc[split:]
return df_val, df_test
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
sql_retriever = SQLRetriever(sql_database)
retrieved_rows = sql_retriever.retrieve(sql)
if retrieved_rows:
# Get the structured result and stringify
return str(retrieved_rows[0].node.metadata["result"])
return ""
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
db_file: File,
table_infos: list[TableInfo | None],
vector_index_dir: Dir,
) -> dict:
# Generate response from target model
table_context = await retrieve_tables(
question, table_infos, db_file, vector_index_dir
)
sql = await generate_sql(
question,
table_context,
target_model_config.model_name,
target_model_config.prompt,
)
sql = sql.replace("sql\n", "")
try:
response = await generate_response(db_file, sql)
except Exception as e:
print(f"Failed to generate response for question {question}: {e}")
response = None
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
query_str=question,
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"sql": sql,
"is_correct": verdict_clean == "true",
}
async def run_grouped_task(
i,
index,
question,
answer,
sql,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
db_file,
table_infos,
vector_index_dir,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
db_file,
table_infos,
vector_index_dir,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "β Yes"
else:
correct_html = "β No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
# Update chart
await flyte.report.log.aio(
f"",
do_flush=True,
)
# Add row to table
await flyte.report.log.aio(
f"""
""",
do_flush=True,
)
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
await flyte.report.log.aio(
f"""
{html.escape(prompt)}
{pct:.1f}%
""",
do_flush=True,
)
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
ground_truth_csv: File | str = "/root/ground_truth.csv",
db_config: DatabaseConfig = DatabaseConfig(
csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob="WikiTableQuestions/csv/200-csv/*.csv",
concurrency=5,
model="gpt-4o-mini",
),
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.
artists(id, name)
albums(id, title, artist_id, release_year)
How many albums did The Beatles release?
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 5,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
ground_truth_csv = await File.from_local(ground_truth_csv)
df_val, df_test = await data_prep(ground_truth_csv)
best_prompt, val_accuracy = await prompt_optimizer(
df_val,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
db_config,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
return {
"best_prompt": best_prompt,
"validation_accuracy": val_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py)
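For brevity, the listing shows the per-row worker (`run_grouped_task`) but not the `evaluate_prompt` driver that fans it out and that `auto_prompt_engineering` calls. The sketch below reconstructs that driver from its call sites; it is an illustration, not the tutorial's exact code, and it assumes a hypothetical `data_ingestion(db_config)` step that returns the database file, table metadata, and vector index the worker expects:
```python
import asyncio

import pandas as pd

async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
    db_config: DatabaseConfig,
) -> float:
    # Hypothetical: materialize the SQLite file, table metadata, and vector
    # index produced by the ingestion step of this tutorial.
    db_file, table_infos, vector_index_dir = await data_ingestion(db_config)
    semaphore = asyncio.Semaphore(concurrency)  # bound concurrent LLM calls
    counter = {"processed": 0, "correct": 0}
    counter_lock = asyncio.Lock()
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            None,  # no precomputed SQL; the worker generates it
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
            db_file,
            table_infos,
            vector_index_dir,
        )
        for i, row in enumerate(df.itertuples())
    ]
    await asyncio.gather(*tasks)
    return counter["correct"] / max(counter["processed"], 1)
```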
The run's UI report shows how prompt accuracy evolves over the optimization iterations.
### Iterative optimization
An optimizer LLM proposes new prompts by analyzing patterns in successful and failed generations. Each candidate runs through the evaluation loop, and we select the best performer.
```
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
ground_truth_csv: File | str = "/root/ground_truth.csv",
db_config: DatabaseConfig = DatabaseConfig(
csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
search_glob="WikiTableQuestions/csv/200-csv/*.csv",
concurrency=5,
model="gpt-4o-mini",
),
target_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1-mini",
hosted_model_uri=None,
prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
max_tokens=10000,
),
review_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
),
optimizer_model_config: ModelConfig = ModelConfig(
model_name="gpt-4.1",
hosted_model_uri=None,
temperature=0.7,
max_tokens=None,
prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.
artists(id, name)
albums(id, title, artist_id, release_year)
How many albums did The Beatles release?
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
),
max_iterations: int = 5,
concurrency: int = 10,
) -> dict[str, Union[str, float]]:
if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
ground_truth_csv = await File.from_local(ground_truth_csv)
df_val, df_test = await data_prep(ground_truth_csv)
best_prompt, val_accuracy = await prompt_optimizer(
df_val,
target_model_config,
review_model_config,
optimizer_model_config,
max_iterations,
concurrency,
db_config,
)
with flyte.group(name="test_data_evaluation"):
baseline_test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
target_model_config.prompt = best_prompt
test_accuracy = await evaluate_prompt(
df_test,
target_model_config,
review_model_config,
concurrency,
db_config,
)
return {
"best_prompt": best_prompt,
"validation_accuracy": val_accuracy,
"baseline_test_accuracy": baseline_test_accuracy,
"test_accuracy": test_accuracy,
}
# {{/docs-fragment auto_prompt_engineering}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(auto_prompt_engineering)
print(run.url)
run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py)
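The body of `prompt_optimizer` itself is elided from the listings above. The sketch below reconstructs such a loop from its call site and from the optimizer prompt's instruction to return the new prompt in double square brackets; the scoring format and control flow are assumptions for illustration, not the tutorial's exact implementation, and the `ModelConfig`, `DatabaseConfig`, `call_model`, and `evaluate_prompt` definitions from the evaluation pipeline are assumed to be in scope:
```python
import re

import pandas as pd

async def prompt_optimizer(
    df_val: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
    db_config: DatabaseConfig,
) -> tuple[str, float]:
    # Score the baseline prompt first so the optimizer has something to build on.
    baseline_accuracy = await evaluate_prompt(
        df_val, target_model_config, review_model_config, concurrency, db_config
    )
    scored = [(target_model_config.prompt, baseline_accuracy)]
    for _ in range(max_iterations):
        # The optimizer prompt expects candidates in ascending order of accuracy.
        scored.sort(key=lambda pair: pair[1])
        prompt_scores_str = "\n\n".join(
            f"Prompt:\n{p}\nAccuracy: {a:.2%}" for p, a in scored
        )
        reply = await call_model(
            optimizer_model_config,
            [{
                "role": "user",
                "content": optimizer_model_config.prompt.format(
                    prompt_scores_str=prompt_scores_str
                ),
            }],
        )
        # The optimizer is instructed to wrap its new prompt in [[...]].
        match = re.search(r"\[\[(.*?)\]\]", reply, re.DOTALL)
        if not match:
            continue  # output format not followed; skip this iteration
        candidate = match.group(1).strip()
        target_model_config.prompt = candidate
        accuracy = await evaluate_prompt(
            df_val, target_model_config, review_model_config, concurrency, db_config
        )
        scored.append((candidate, accuracy))
    best_prompt, best_accuracy = max(scored, key=lambda pair: pair[1])
    return best_prompt, best_accuracy
```
Each candidate is scored on the validation split only; the held-out test split is reserved for the final baseline-versus-best comparison in `auto_prompt_engineering`.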
On paper, this creates a continuous improvement cycle: baseline → new variants → measured gains.
## Run it
To create the QA dataset:
```
python create_qa_dataset.py
```
To run the prompt optimization loop:
```
python optimizer.py
```
## What we observed
Prompt optimization didn't consistently lift SQL accuracy in this workflow. Accuracy plateaued near the baseline. But the process surfaced valuable lessons about what matters when building LLM-powered systems on real infrastructure.
- **Schema clarity matters**: CSV ingestion produced tables with overlapping names, creating ambiguity. This showed how schema design and metadata hygiene directly affect downstream evaluation.
- **Ground truth needs trust**: Because the dataset came from LLM outputs, noise remained even after filtering. Human review proved essential. Golden datasets need deliberate curation, not just automation.
- **Optimization needs context**: The optimizer couldn't "see" which examples failed, limiting its ability to improve. Feeding failures directly risks overfitting. A structured way to capture and reuse evaluation signals is the right long-term path.
Sometimes prompt tweaks alone can lift accuracy, but other times the real bottleneck lives in the data, the schema, or the evaluation loop. The lesson isn't "prompt optimization doesn't work", but that its impact depends on the system around it. Accuracy improves most reliably when prompts evolve alongside clean data, trusted evaluation, and observable feedback loops.
## The bigger lesson
Evaluation and optimization aren't one-off experiments; they're continuous processes. What makes them sustainable isn't a clever prompt; it's the platform around it.
Systems succeed when they:
- **Observe** failures with clarity: track exactly what failed and why.
- **Remain durable** across iterations: run pipelines that are stable, reproducible, and comparable over time.
That's where Flyte 2 comes in. Prompt optimization is one lever, but it becomes powerful only when combined with:
- Clean, human-validated evaluation datasets.
- Systematic reporting and feedback loops.
**The real takeaway: improving LLM pipelines isn't about chasing the perfect prompt. It's about designing workflows with observability and durability at the core, so that every experiment compounds into long-term progress.**
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations ===
# Integrations
Flyte is designed to be highly extensible and can be customized
in multiple ways.
## Flyte Plugins
Flyte plugins extend the functionality of the `flyte` SDK.
| Plugin | Description |
| ------ | ----------- |
| **Flyte plugins > Ray** | Run Ray jobs on your Flyte cluster |
| **Flyte plugins > Spark** | Run Spark jobs on your Flyte cluster |
| **Flyte plugins > OpenAI** | Integrate with OpenAI SDKs in your Flyte workflows |
| **Flyte plugins > Dask** | Run Dask jobs on your Flyte cluster |
## Subpages
- **Connectors**
- **Flyte plugins**
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/connectors ===
# Connectors
Connectors are stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems. Each connector runs as its own Kubernetes deployment and is triggered when a Flyte task of the matching type is executed. For example: when a `BigQueryTask` is launched, the BigQuery connector receives the request and creates a BigQuery job.
Although they normally run inside the control plane, you can also run connectors locally, as long as the required secrets/credentials are present, because connectors are just Python services that can be spawned in-process.
Connectors are designed to scale horizontally and reduce load on the core Flyte backend because they execute *outside* the core system. This decoupling makes connectors efficient, resilient, and easy to iterate on. You can even test them locally without modifying backend configuration, which reduces friction during development.
## Creating a new connector
If none of the existing connectors meet your needs, you can build your own.
> [!NOTE]
> Connectors communicate via Protobuf, so in theory they can be implemented in any language.
> Today, only **Python** connectors are supported.
### Async connector interface
To implement a new async connector, extend `AsyncConnector` and implement the following methods, all of which **must be idempotent**:
| Method | Purpose |
|----------|-------------------------------------------------------------|
| `create` | Launch the external job (via REST, gRPC, SDK, or other API) |
| `get` | Fetch current job state (return job status or output) |
| `delete` | Delete / cancel the external job |
To test the connector locally, the connector task should inherit from
[AsyncConnectorExecutorMixin](https://github.com/flyteorg/flyte-sdk/blob/1d49299294cd5e15385fe8c48089b3454b7a4cd1/src/flyte/connectors/_connector.py#L206).
This mixin simulates how the Flyte system executes asynchronous connector tasks, making it easier to validate your connector implementation before deploying it.
```python
from dataclasses import dataclass
from flyte.connectors import AsyncConnector, Resource, ResourceMeta
from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
from flyteidl2.core.tasks_pb2 import TaskTemplate
from google.protobuf import json_format
import typing
import httpx
@dataclass
class ModelTrainJobMeta(ResourceMeta):
job_id: str
endpoint: str
class ModelTrainingConnector(AsyncConnector):
"""
Example connector that launches a ML model training job on an external training service.
POST β launch training job
GET β poll training progress
DELETE β cancel training job
"""
name = "Model Training Connector"
task_type_name = "external_model_training"
metadata_type = ModelTrainJobMeta
async def create(
self,
task_template: TaskTemplate,
inputs: typing.Optional[typing.Dict[str, typing.Any]],
**kwargs,
) -> ModelTrainJobMeta:
"""
Submit training job via POST.
Response returns job_id we later use in get().
"""
custom = json_format.MessageToDict(task_template.custom) if task_template.custom else None
async with httpx.AsyncClient() as client:
r = await client.post(
custom["endpoint"],
json={"dataset_uri": inputs["dataset_uri"], "epochs": inputs["epochs"]},
)
r.raise_for_status()
return ModelTrainJobMeta(job_id=r.json()["job_id"], endpoint=custom["endpoint"])
async def get(self, resource_meta: ModelTrainJobMeta, **kwargs) -> Resource:
"""
Poll external API until training job finishes.
Must be safe to call repeatedly.
"""
async with httpx.AsyncClient() as client:
r = await client.get(f"{resource_meta.endpoint}/{resource_meta.job_id}")
if r.status_code != 200:
return Resource(phase=TaskExecution.RUNNING)
data = r.json()
if data["status"] == "finished":
return Resource(
phase=TaskExecution.SUCCEEDED,
log_links=[TaskLog(name="training-dashboard", uri=f"https://example-mltrain.com/train/{resource_meta.job_id}")],
outputs={"results": data["results"]},
)
return Resource(phase=TaskExecution.RUNNING)
async def delete(self, resource_meta: ModelTrainJobMeta, **kwargs):
"""
Optionally call DELETE on external API.
Safe even if job already completed.
"""
async with httpx.AsyncClient() as client:
await client.delete(f"{resource_meta.endpoint}/{resource_meta.job_id}")
```
To actually use this connector, you must also define a task whose `task_type` matches the connector.
```python
import flyte.io
from typing import Any, Dict, Optional
from flyte.extend import TaskTemplate
from flyte.connectors import AsyncConnectorExecutorMixin
from flyte.models import NativeInterface, SerializationContext
class ModelTrainTask(AsyncConnectorExecutorMixin, TaskTemplate):
_TASK_TYPE = "external_model_training"
def __init__(
self,
name: str,
endpoint: str,
**kwargs,
):
super().__init__(
name=name,
interface=NativeInterface(
inputs={"epochs": int, "dataset_uri": str},
outputs={"results": flyte.io.File},
),
task_type=self._TASK_TYPE,
**kwargs,
)
self.endpoint = endpoint
def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
return {"endpoint": self.endpoint}
```
Here is an example of how to use the `ModelTrainTask`:
```python
import flyte
env = flyte.TaskEnvironment(name="hello_world", resources=flyte.Resources(memory="250Mi"))
model_train_task = ModelTrainTask(
name="model_train",
endpoint="https://example-mltrain.com",
)
@env.task
def data_prep() -> str:
return "gs://my-bucket/dataset.csv"
@env.task
def train_model(epochs: int) -> flyte.io.File:
dataset_uri = data_prep()
return model_train_task(epochs=epochs, dataset_uri=dataset_uri)
```
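Because `ModelTrainTask` mixes in `AsyncConnectorExecutorMixin`, launching `train_model` exercises the connector's `create`/`get`/`delete` logic in-process, which lets you validate it before deploying anything. Continuing the snippet above, a minimal launch sketch following the run pattern used elsewhere in these docs (note that `https://example-mltrain.com` is a placeholder, so a real run needs a reachable training service):
```python
if __name__ == "__main__":
    flyte.init_from_config()
    # train_model submits the external training job through
    # ModelTrainingConnector and polls it until completion.
    run = flyte.run(train_model, epochs=5)
    print(run.url)
    run.wait()
```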
## Build Connector Docker Image
Build the custom image when you're ready to deploy your connector to your cluster.
To build the Docker image for your connector, run the following script:
```python
import asyncio
from flyte import Image
from flyte.extend import ImageBuildEngine
async def build_flyte_connector_image(
registry: str, name: str, builder: str = "local"
):
"""
Build the SDK default connector image, optionally overriding
the container registry and image name.
Args:
registry: e.g. "ghcr.io/my-org" or "123456789012.dkr.ecr.us-west-2.amazonaws.com".
name: e.g. "my-connector".
builder: e.g. "local" or "remote".
"""
default_image = Image.from_debian_base(registry=registry, name=name).with_pip_packages(
"flyteplugins-connectors[bigquery]", pre=True
)
await ImageBuildEngine.build(default_image, builder=builder)
if __name__ == "__main__":
print("Building connector image...")
asyncio.run(build_flyte_connector_image(registry="", name="flyte-connectors", builder="local"))
```
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/flyte-plugins ===
# Flyte plugins
Flyte is designed to be extensible, allowing you to integrate new tools and frameworks into your workflows. By installing and configuring plugins, you can tailor Flyte to your data and compute ecosystem, whether you need to run large-scale distributed training, process data with a specific engine, or interact with external APIs.
Common reasons to extend Flyte include:
- **Specialized compute:** Use plugins like Spark or Ray to create distributed compute clusters.
- **AI integration:** Connect Flyte with frameworks like OpenAI to run LLM agentic applications.
- **Custom infrastructure:** Add plugins to interface with your organization's storage, databases, or proprietary systems.
For example, you can install the PyTorch plugin to run distributed PyTorch jobs natively on a Kubernetes cluster.
| Plugin | Description |
| ------ | ----------- |
| **Flyte plugins > Ray** | Run Ray jobs on your Flyte cluster |
| **Flyte plugins > Spark** | Run Spark jobs on your Flyte cluster |
| **Flyte plugins > OpenAI** | Integrate with OpenAI SDKs in your Flyte workflows |
| **Flyte plugins > Dask** | Run Dask jobs on your Flyte cluster |
## Subpages
- **Flyte plugins > Dask**
- **Flyte plugins > OpenAI**
- **Flyte plugins > Pytorch**
- **Flyte plugins > Ray**
- **Flyte plugins > Spark**
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/flyte-plugins/dask ===
# Dask
Flyte can execute Dask jobs natively on a Kubernetes cluster, managing the Dask cluster's lifecycle: spin-up and tear-down. It leverages the open-source Dask Kubernetes Operator and can be enabled without signing up for any service. This is like running a transient Dask cluster: one spun up for a specific Dask job and torn down after completion.
## Install the plugin
To install the Dask plugin, run the following command:
```shell
$ pip install --pre flyteplugins-dask
```
The following example shows how to configure Dask in a `TaskEnvironment`. Flyte automatically provisions a Dask cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-dask",
# "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///
import asyncio
import typing
from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup
import flyte.remote
import flyte.storage
from flyte import Resources
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyteplugins-dask")
dask_config = Dask(
scheduler=Scheduler(),
workers=WorkerGroup(number_of_workers=4),
)
task_env = flyte.TaskEnvironment(
name="hello_dask", resources=Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
dask_env = flyte.TaskEnvironment(
name="dask_env",
plugin_config=dask_config,
image=image,
resources=Resources(cpu="1", memory="1Gi"),
depends_on=[task_env],
)
@task_env.task()
async def hello_dask():
await asyncio.sleep(5)
print("Hello from the Dask task!")
@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
print("running dask task")
t = asyncio.create_task(hello_dask())
client = Client()
futures = client.map(lambda x: x + 1, range(n))
res = client.gather(futures)
await t
return res
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(hello_dask_nested)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/dask/dask_example.py)
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/flyte-plugins/openai ===
# OpenAI
Flyte integrates with OpenAI SDKs in your workflows.
It provides drop-in replacements for libraries like `openai-agents` so that
you can build LLM-augmented workflows and agentic applications on Flyte.
## Install the plugin
To install the OpenAI plugin, run the following command:
```bash
pip install --pre flyteplugins-openai
```
## Subpages
- **Flyte plugins > OpenAI > Agent tools**
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/flyte-plugins/openai/agent_tools ===
# Agent tools
In this example, we will use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
This example is based on the [basic tools example](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tools.py) from the `openai-agents-python` repo.
First, create an OpenAI API key, which you can get from the [OpenAI website](https://platform.openai.com/account/api-keys).
Then, create a secret on your Flyte cluster with:
```
flyte create secret openai_api_key --value <your-api-key>
```
Then, we'll use an inline `uv` script header to specify our dependencies. The full example script is shown below:
```
"""OpenAI Agents with Flyte, basic tool example.
Usage:
Create secret:
```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""
# {{docs-fragment uv-script}}
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-openai>=2.0.0b7",
# "openai-agents>=0.2.4",
# "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
# {{/docs-fragment uv-script}}
# {{docs-fragment imports-task-env}}
from agents import Agent, Runner
from pydantic import BaseModel
import flyte
from flyteplugins.openai.agents import function_tool
env = flyte.TaskEnvironment(
name="openai_agents_tools",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
# {{/docs-fragment imports-task-env}}
# {{docs-fragment tools}}
class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
@function_tool
@env.task
async def get_weather(city: str) -> Weather:
"""Get the weather for a given city."""
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
# {{/docs-fragment tools}}
# {{docs-fragment agent}}
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[get_weather],
)
@env.task
async def main() -> str:
result = await Runner.run(agent, input="What's the weather in Tokyo?")
print(result.final_output)
return result.final_output
# {{/docs-fragment agent}}
# {{docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py)
Next, we'll import the libraries and create a `TaskEnvironment`, which we need to run the example:
```
"""OpenAI Agents with Flyte, basic tool example.
Usage:
Create secret:
```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""
# {{docs-fragment uv-script}}
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-openai>=2.0.0b7",
# "openai-agents>=0.2.4",
# "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
# {{/docs-fragment uv-script}}
# {{docs-fragment imports-task-env}}
from agents import Agent, Runner
from pydantic import BaseModel
import flyte
from flyteplugins.openai.agents import function_tool
env = flyte.TaskEnvironment(
name="openai_agents_tools",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
# {{/docs-fragment imports-task-env}}
# {{docs-fragment tools}}
class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
@function_tool
@env.task
async def get_weather(city: str) -> Weather:
"""Get the weather for a given city."""
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
# {{/docs-fragment tools}}
# {{docs-fragment agent}}
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[get_weather],
)
@env.task
async def main() -> str:
result = await Runner.run(agent, input="What's the weather in Tokyo?")
print(result.final_output)
return result.final_output
# {{/docs-fragment agent}}
# {{docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py)
## Define the tools
We'll define a tool that can get weather information for a
given city. In this case, we'll use a toy function that returns a hard-coded `Weather` object.
```
"""OpenAI Agents with Flyte, basic tool example.
Usage:
Create secret:
```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""
# {{docs-fragment uv-script}}
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-openai>=2.0.0b7",
# "openai-agents>=0.2.4",
# "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
# {{/docs-fragment uv-script}}
# {{docs-fragment imports-task-env}}
from agents import Agent, Runner
from pydantic import BaseModel
import flyte
from flyteplugins.openai.agents import function_tool
env = flyte.TaskEnvironment(
name="openai_agents_tools",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
# {{/docs-fragment imports-task-env}}
# {{docs-fragment tools}}
class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
@function_tool
@env.task
async def get_weather(city: str) -> Weather:
"""Get the weather for a given city."""
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
# {{/docs-fragment tools}}
# {{docs-fragment agent}}
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[get_weather],
)
@env.task
async def main() -> str:
result = await Runner.run(agent, input="What's the weather in Tokyo?")
print(result.final_output)
return result.final_output
# {{/docs-fragment agent}}
# {{docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py)
In this code snippet, the `@function_tool` decorator is imported from `flyteplugins.openai.agents`, which is a drop-in replacement for the `@function_tool` decorator from `openai-agents` library.
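To make the substitution concrete: relative to the upstream `openai-agents` example, only the import changes. A minimal sketch of the two imports side by side:
```python
# Upstream decorator from the openai-agents library:
# from agents import function_tool

# Flyte drop-in replacement, which can be stacked on an @env.task-decorated
# function (as in get_weather above):
from flyteplugins.openai.agents import function_tool
```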
## Define the agent
Then, we'll define the agent, which calls the tool:
```
"""OpenAI Agents with Flyte, basic tool example.
Usage:
Create secret:
```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""
# {{docs-fragment uv-script}}
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-openai>=2.0.0b7",
# "openai-agents>=0.2.4",
# "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
# {{/docs-fragment uv-script}}
# {{docs-fragment imports-task-env}}
from agents import Agent, Runner
from pydantic import BaseModel
import flyte
from flyteplugins.openai.agents import function_tool
env = flyte.TaskEnvironment(
name="openai_agents_tools",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
# {{/docs-fragment imports-task-env}}
# {{docs-fragment tools}}
class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
@function_tool
@env.task
async def get_weather(city: str) -> Weather:
"""Get the weather for a given city."""
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
# {{/docs-fragment tools}}
# {{docs-fragment agent}}
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[get_weather],
)
@env.task
async def main() -> str:
result = await Runner.run(agent, input="What's the weather in Tokyo?")
print(result.final_output)
return result.final_output
# {{/docs-fragment agent}}
# {{docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py)
## Run the agent
Finally, we'll run the agent. Create a `config.yaml` file, which the `flyte.init_from_config()` function will use to connect to the Flyte cluster:
```bash
flyte create config \
--output ~/.flyte/config.yaml \
--endpoint demo.hosted.unionai.cloud/ \
--project flytesnacks \
--domain development \
--builder remote
```
```
"""OpenAI Agents with Flyte, basic tool example.
Usage:
Create secret:
```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""
# {{docs-fragment uv-script}}
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-openai>=2.0.0b7",
# "openai-agents>=0.2.4",
# "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
# {{/docs-fragment uv-script}}
# {{docs-fragment imports-task-env}}
from agents import Agent, Runner
from pydantic import BaseModel
import flyte
from flyteplugins.openai.agents import function_tool
env = flyte.TaskEnvironment(
name="openai_agents_tools",
resources=flyte.Resources(cpu=1, memory="250Mi"),
image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
# {{/docs-fragment imports-task-env}}
# {{docs-fragment tools}}
class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
@function_tool
@env.task
async def get_weather(city: str) -> Weather:
"""Get the weather for a given city."""
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
# {{/docs-fragment tools}}
# {{docs-fragment agent}}
agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
tools=[get_weather],
)
@env.task
async def main() -> str:
result = await Runner.run(agent, input="What's the weather in Tokyo?")
print(result.final_output)
return result.final_output
# {{/docs-fragment agent}}
# {{docs-fragment main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py)
## Conclusion
In this example, we've seen how to use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
The full code is available [here](https://github.com/unionai/unionai-examples/tree/main/v2/integrations/flyte-plugins/openai/openai).
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/flyte-plugins/pytorch ===
# Pytorch
Flyte can execute distributed PyTorch jobs (similar to running a `torchrun` script) natively on a Kubernetes cluster, managing the cluster's lifecycle: spin-up and tear-down. It leverages the open-source Kubeflow Training Operator. This is like running a transient PyTorch cluster: one spun up for a specific PyTorch job and torn down after completion.
To install the plugin, run the following command:
```shell
$ pip install --pre flyteplugins-pytorch
```
The following example shows how to configure PyTorch in a `TaskEnvironment`. Flyte automatically provisions a PyTorch cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-pytorch",
# "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///
import typing
import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import flyte
image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)
torch_env = flyte.TaskEnvironment(
name="torch_env",
resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
plugin_config=Elastic(
nproc_per_node=1,
# if you want to do local testing set nnodes=1
nnodes=2,
),
image=image,
)
class LinearRegressionModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
"""
Prepare a DataLoader with a DistributedSampler so each rank
gets a shard of the dataset.
"""
# Dummy dataset
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
dataset = TensorDataset(x_train, y_train)
# Distributed-aware sampler
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
def train_loop(epochs: int = 3) -> float:
"""
A simple training loop for linear regression.
"""
torch.distributed.init_process_group("gloo")
model = DDP(LinearRegressionModel())
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
dataloader = prepare_dataloader(
rank=rank,
world_size=world_size,
batch_size=64,
)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
final_loss = 0.0
for _ in range(epochs):
for x, y in dataloader:
outputs = model(x)
loss = criterion(outputs, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
final_loss = loss.item()
if torch.distributed.get_rank() == 0:
print(f"Loss: {final_loss}")
return final_loss
@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
"""
A task that runs a simple distributed training job using PyTorch's DistributedDataParallel (DDP).
"""
print("starting launcher")
loss = train_loop(epochs=epochs)
print("Training complete")
return loss
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(torch_distributed_train, epochs=3)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pytorch/pytorch_example.py)
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/flyte-plugins/ray ===
# Ray
Flyte can execute Ray jobs natively on a Kubernetes cluster, managing the virtual cluster's lifecycle: spin-up and tear-down. It leverages the open-source [KubeRay](https://github.com/ray-project/kuberay) operator and can be enabled without signing up for any service. This is like running a transient Ray cluster: one spun up for a specific Ray job and torn down after completion.
## Install the plugin
To install the Ray plugin, run the following command:
```shell
$ pip install --pre flyteplugins-ray
```
The following example shows how to configure Ray in a `TaskEnvironment`. Flyte automatically provisions a Ray cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-ray",
# "ray[default]==2.46.0"
# ]
# main = "hello_ray_nested"
# params = "3"
# ///
import asyncio
import typing
import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
import flyte.remote
import flyte.storage
@ray.remote
def f(x):
return x * x
ray_config = RayJobConfig(
head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
runtime_env={"pip": ["numpy", "pandas"]},
enable_autoscaling=False,
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=300,
)
image = (
flyte.Image.from_debian_base(name="ray")
.with_apt_packages("wget")
.with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray", "pip", "mypy")
)
task_env = flyte.TaskEnvironment(
name="hello_ray", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
ray_env = flyte.TaskEnvironment(
name="ray_env",
plugin_config=ray_config,
image=image,
resources=flyte.Resources(cpu=(3, 4), memory=("3000Mi", "5000Mi")),
depends_on=[task_env],
)
@task_env.task()
async def hello_ray():
await asyncio.sleep(20)
print("Hello from the Ray task!")
@ray_env.task
async def hello_ray_nested(n: int = 3) -> typing.List[int]:
print("running ray task")
t = asyncio.create_task(hello_ray())
futures = [f.remote(i) for i in range(n)]
res = ray.get(futures)
await t
return res
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(hello_ray_nested)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_example.py)
The next example demonstrates how Flyte can create ephemeral Ray clusters and run a subtask that connects to an existing Ray cluster:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-ray",
# "ray[default]==2.46.0"
# ]
# main = "create_ray_cluster"
# params = ""
# ///
import os
import typing
import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
import flyte.storage
@ray.remote
def f(x):
return x * x
ray_config = RayJobConfig(
head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
enable_autoscaling=False,
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=3600,
)
image = (
flyte.Image.from_debian_base(name="ray")
.with_apt_packages("wget")
.with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)
task_env = flyte.TaskEnvironment(
name="ray_client", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
ray_env = flyte.TaskEnvironment(
name="ray_cluster",
plugin_config=ray_config,
image=image,
resources=flyte.Resources(cpu=(2, 4), memory=("2000Mi", "4000Mi")),
depends_on=[task_env],
)
@task_env.task()
async def hello_ray(cluster_ip: str) -> typing.List[int]:
"""
Run a simple Ray task that connects to an existing Ray cluster.
"""
ray.init(address=f"ray://{cluster_ip}:10001")
futures = [f.remote(i) for i in range(5)]
res = ray.get(futures)
return res
@ray_env.task
async def create_ray_cluster() -> str:
"""
Create a Ray cluster and return the head node IP address.
"""
print("creating ray cluster")
cluster_ip = os.getenv("MY_POD_IP")
if cluster_ip is None:
raise ValueError("MY_POD_IP environment variable is not set")
return f"{cluster_ip}"
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(create_ray_cluster)
run.wait()
print("run url:", run.url)
print("cluster created, running ray task")
print("ray address:", run.outputs()[0])
run = flyte.run(hello_ray, cluster_ip=run.outputs()[0])
print("run url:", run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_existing_example.py)
=== PAGE: https://www.union.ai/docs/v2/flyte/integrations/flyte-plugins/spark ===
# Spark
Flyte can execute Spark jobs natively on a Kubernetes cluster, managing the virtual cluster's lifecycle: spin-up and tear-down. It leverages the open-source Spark on K8s Operator and can be enabled without signing up for any service. This is like running a transient Spark cluster: one spun up for a specific Spark job and torn down after completion.
To install the plugin, run the following command:
```bash
pip install --pre flyteplugins-spark
```
The following example shows how to configure Spark in a `TaskEnvironment`. Flyte automatically provisions a Spark cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-spark"
# ]
# main = "hello_spark_nested"
# params = "3"
# ///
import random
from copy import deepcopy
from operator import add
from flyteplugins.spark.task import Spark
import flyte.remote
from flyte._context import internal_ctx
image = (
flyte.Image.from_base("apache/spark-py:v3.4.0")
.clone(name="spark", python_version=(3, 10), registry="ghcr.io/flyteorg")
.with_pip_packages("flyteplugins-spark", pre=True)
)
task_env = flyte.TaskEnvironment(
name="get_pi", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
spark_conf = Spark(
spark_conf={
"spark.driver.memory": "3000M",
"spark.executor.memory": "1000M",
"spark.executor.cores": "1",
"spark.executor.instances": "2",
"spark.driver.cores": "1",
"spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
"spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar",
},
)
spark_env = flyte.TaskEnvironment(
name="spark_env",
resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
plugin_config=spark_conf,
image=image,
depends_on=[task_env],
)
def f(_):
x = random.random() * 2 - 1
y = random.random() * 2 - 1
return 1 if x**2 + y**2 <= 1 else 0
@task_env.task
async def get_pi(count: int, partitions: int) -> float:
return 4.0 * count / partitions
@spark_env.task
async def hello_spark_nested(partitions: int = 3) -> float:
n = 1 * partitions
ctx = internal_ctx()
spark = ctx.data.task_context.data["spark_session"]
count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
return await get_pi(count, partitions)
@task_env.task
async def spark_overrider(executor_instances: int = 3, partitions: int = 4) -> float:
updated_spark_conf = deepcopy(spark_conf)
updated_spark_conf.spark_conf["spark.executor.instances"] = str(executor_instances)
return await hello_spark_nested.override(plugin_config=updated_spark_conf)(partitions=partitions)
if __name__ == "__main__":
flyte.init_from_config()
r = flyte.run(hello_spark_nested)
print(r.name)
print(r.url)
r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/spark/spark_example.py)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference ===
# Reference
This section provides the reference material for the Flyte SDK and CLI.
To get started, add `flyte` to your project:
```shell
$ uv pip install --no-cache --prerelease=allow --upgrade flyte
```
This will install both the Flyte SDK and CLI.
### π **Flyte SDK**
The Flyte SDK provides the core Python API for building workflows and apps on your Union instance.
### π **Flyte CLI**
The Flyte CLI is the command-line interface for interacting with your Union instance.
## Subpages
- **Flyte CLI**
- **LLM context document**
- **Flyte SDK**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-cli ===
# Flyte CLI
This is the command line interface for Flyte.
| Object | Action |
| ------ | ------ |
| `run` | `flyte abort run`, `flyte get run` |
| `api-key` | `flyte create api-key`, `flyte delete api-key`, `flyte get api-key` |
| `config` | `flyte create config`, `flyte get config` |
| `secret` | `flyte create secret`, `flyte delete secret`, `flyte get secret` |
| `trigger` | `flyte create trigger`, `flyte delete trigger`, `flyte get trigger`, `flyte update trigger` |
| `docs` | `flyte gen docs` |
| `action` | `flyte get action` |
| `app` | `flyte get app`, `flyte update app` |
| `io` | `flyte get io` |
| `logs` | `flyte get logs` |
| `project` | `flyte get project` |
| `task` | `flyte get task` |
| `hf-model` | `flyte prefetch hf-model` |
| `deployed-task` | `flyte run deployed-task` |

**▸** Plugin command: see the command's documentation for installation instructions.
| Action | On |
| ------ | -- |
| `abort` | `flyte abort run` |
| `build` | - |
| `create` | `flyte create api-key`, `flyte create config`, `flyte create secret`, `flyte create trigger` |
| `delete` | `flyte delete api-key`, `flyte delete secret`, `flyte delete trigger` |
| `deploy` | - |
| `gen` | `flyte gen docs` |
| `get` | `flyte get action`, `flyte get api-key`, `flyte get app`, `flyte get config`, `flyte get io`, `flyte get logs`, `flyte get project`, `flyte get run`, `flyte get secret`, `flyte get task`, `flyte get trigger` |
| `prefetch` | `flyte prefetch hf-model` |
| `run` | `flyte run deployed-task` |
| `serve` | - |
| `update` | `flyte update app`, `flyte update trigger` |
| `whoami` | - |

**▸** Plugin command: see the command's documentation for installation instructions.
## flyte
**`flyte [OPTIONS] COMMAND [ARGS]...`**
The Flyte CLI is the command line interface for working with the Flyte SDK and backend.
It follows a simple verb/noun structure,
where the top-level commands are verbs that describe the action to be taken,
and the subcommands are nouns that describe the object of the action.
The root command can be used to configure the CLI for persistent settings,
such as the endpoint, organization, and verbosity level.
Set endpoint and organization:
```bash
$ flyte --endpoint <endpoint> --org <org> get project
```
Increase the verbosity level (useful for debugging; this shows more logs and exception traces):
```bash
$ flyte -vvv get logs
```
Override the default config file:
```bash
$ flyte --config /path/to/config.yaml run ...
```
* [Documentation](https://www.union.ai/docs/flyte/user-guide/)
* [GitHub](https://github.com/flyteorg/flyte): Please leave a star if you like Flyte!
* [Slack](https://slack.flyte.org): Join the community and ask questions.
* [Issues](https://github.com/flyteorg/flyte/issues)
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--version` | `boolean` | `False` | Show the version and exit. |
| `--endpoint` | `text` | `Sentinel.UNSET` | The endpoint to connect to. This will override any configuration file and simply use `pkce` to connect. |
| `--insecure` | `boolean` | | Use an insecure connection to the endpoint. If not specified, the CLI will use TLS. |
| `--auth-type` | `choice` | | Authentication type to use for the Flyte backend. Defaults to 'pkce'. |
| `-v`, `--verbose` | `integer` | `0` | Show verbose messages and exception traces. Repeating multiple times increases the verbosity (e.g., -vvv). |
| `--org` | `text` | `Sentinel.UNSET` | The organization to which the command applies. |
| `-c`, `--config` | `path` | `Sentinel.UNSET` | Path to the configuration file to use. If not specified, the default configuration file is used. |
| `--output-format`, `-of` | `choice` | `table` | Output format for commands that support it. Defaults to 'table'. |
| `--log-format` | `choice` | `console` | Formatting for logs, defaults to 'console' which is meant to be human readable. 'json' is meant for machine parsing. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte abort
**`flyte abort COMMAND [ARGS]...`**
Abort an ongoing process.
#### flyte abort run
**`flyte abort run [OPTIONS] RUN_NAME`**
Abort a run.
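For example, to abort a run named `my_run` (a hypothetical run name):
```bash
$ flyte abort run my_run
```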
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte build
**`flyte build [OPTIONS] COMMAND [ARGS]...`**
Build the environments defined in a python file or directory. This will build the images associated with the
environments.
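A minimal sketch of an invocation, assuming the same file-then-environment argument pattern as `flyte deploy` (the file and environment names here are hypothetical):
```bash
$ flyte build hello.py my_env
```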
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--noop` | `boolean` | `Sentinel.UNSET` | Dummy parameter, placeholder for future use. Does not affect the build process. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte create
**`flyte create COMMAND [ARGS]...`**
Create resources in a Flyte deployment.
#### flyte create api-key
> **Note:** This command is provided by the `flyteplugins.union` plugin. See the plugin documentation for installation instructions.
**`flyte create api-key [OPTIONS]`**
Create an API key for headless authentication.
This creates OAuth application credentials that can be used to authenticate
with Union without interactive login. The generated API key should be set
as the FLYTE_API_KEY environment variable. OAuth applications should not be
confused with Union Apps, which are a different construct entirely.
Examples:
```bash
# Create an API key named "ci-pipeline"
$ flyte create api-key --name ci-pipeline

# The output will include an export command like:
# export FLYTE_API_KEY=""
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--name` | `text` | `Sentinel.UNSET` | Name for API key |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create config
**`flyte create config [OPTIONS]`**
Creates a configuration file for Flyte CLI.
If the `--output` option is not specified, the file is created at `.flyte/config.yaml` in the current directory.
If the file already exists, it will raise an error unless the `--force` option is used.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--endpoint` | `text` | `Sentinel.UNSET` | Endpoint of the Flyte backend. |
| `--insecure` | `boolean` | `False` | Use an insecure connection to the Flyte backend. |
| `--org` | `text` | `Sentinel.UNSET` | Organization to use. This will override the organization in the configuration file. |
| `-o`, `--output` | `path` | `.flyte/config.yaml` | Path to the output directory where the configuration will be saved. Defaults to current directory. |
| `--force` | `boolean` | `False` | Force overwrite of the configuration file if it already exists. |
| `--image-builder`, `--builder` | `choice` | `local` | Image builder to use for building images. Defaults to 'local'. |
| `--auth-type` | `choice` | | Authentication type to use for the Flyte backend. Defaults to 'pkce'. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create secret
**`flyte create secret [OPTIONS] NAME`**
Create a new secret. The name of the secret is required. For example:
```bash
$ flyte create secret my_secret --value my_value
```
If you don't provide a `--value` flag, you will be prompted to enter the
secret value in the terminal.
```bash
$ flyte create secret my_secret
Enter secret value:
```
If `--from-file` is specified, the value will be read from the file instead of being provided directly:
```bash
$ flyte create secret my_secret --from-file /path/to/secret_file
```
The `--type` option can be used to create specific types of secrets.
Either `regular` or `image_pull` can be specified.
Secrets intended to access container images should be specified as `image_pull`.
Other secrets should be specified as `regular`.
If no type is specified, `regular` is assumed.
For image pull secrets, you have several options:
1. Interactive mode (prompts for registry, username, password):
```bash
$ flyte create secret my_secret --type image_pull
```
2. With explicit credentials:
```bash
$ flyte create secret my_secret --type image_pull --registry ghcr.io --username myuser
```
3. Lastly, you can create a secret from your existing Docker installation (i.e., you've run `docker login` in
the past), reusing those credentials. Since you may have logged in to multiple registries,
you can specify which registries to include. If no registries are specified, all registries are added.
```bash
$ flyte create secret my_secret --type image_pull --from-docker-config --registries ghcr.io,docker.io
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--value` | `text` | `Sentinel.UNSET` | Secret value. Mutually exclusive with from_file, from_docker_config, registry. |
| `--from-file` | `path` | `Sentinel.UNSET` | Path to the file with the binary secret. Mutually exclusive with value, from_docker_config, registry. |
| `--type` | `choice` | `regular` | Type of the secret. |
| `--from-docker-config` | `boolean` | `False` | Create image pull secret from Docker config file (only for --type image_pull). Mutually exclusive with value, from_file, registry, username, password. |
| `--docker-config-path` | `path` | `Sentinel.UNSET` | Path to Docker config file (defaults to ~/.docker/config.json or $DOCKER_CONFIG). |
| `--registries` | `text` | `Sentinel.UNSET` | Comma-separated list of registries to include (only with --from-docker-config). |
| `--registry` | `text` | `Sentinel.UNSET` | Registry hostname (e.g., ghcr.io, docker.io) for explicit credentials (only for --type image_pull). Mutually exclusive with value, from_file, from_docker_config. |
| `--username` | `text` | `Sentinel.UNSET` | Username for the registry (only with --registry). |
| `--password` | `text` | `Sentinel.UNSET` | Password for the registry (only with --registry). If not provided, will prompt. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create trigger
**`flyte create trigger [OPTIONS] TASK_NAME NAME`**
Create a new trigger for a task. The task name and trigger name are required.
Example:
```bash
$ flyte create trigger my_task my_trigger --schedule "0 0 * * *"
```
This will create a trigger that runs every day at midnight.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--schedule` | `text` | `Sentinel.UNSET` | Cron schedule for the trigger. Defaults to every minute. |
| `--description` | `text` | `` | Description of the trigger. |
| `--auto-activate` | `boolean` | `True` | Whether the trigger should be automatically activated. Defaults to True. |
| `--trigger-time-var` | `text` | `trigger_time` | Variable name for the trigger time in the task inputs. Defaults to 'trigger_time'. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte delete
**`flyte delete COMMAND [ARGS]...`**
Remove resources from a Flyte deployment.
#### flyte delete api-key
> **Note:** This command is provided by the `flyteplugins.union` plugin. See the plugin documentation for installation instructions.
**`flyte delete api-key [OPTIONS] CLIENT_ID`**
Delete an API key.
Examples:
```bash
# Delete an API key (with confirmation)
$ flyte delete api-key my-client-id

# Delete without confirmation
$ flyte delete api-key my-client-id --yes
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--yes` | `boolean` | `False` | Skip confirmation prompt |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte delete secret
**`flyte delete secret [OPTIONS] NAME`**
Delete a secret. The name of the secret is required.
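For example, to delete a secret named `my_secret`:
```bash
$ flyte delete secret my_secret
```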
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte delete trigger
**`flyte delete trigger [OPTIONS] NAME TASK_NAME`**
Delete a trigger. The name of the trigger is required.
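For example, to delete a trigger named `my_trigger` on the task `my_task` (hypothetical names):
```bash
$ flyte delete trigger my_trigger my_task
```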
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte deploy
**`flyte deploy [OPTIONS] COMMAND [ARGS]...`**
Deploy one or more environments from a python file.
This command will create or update environments in the Flyte system, registering
all tasks and their dependencies.
Example usage:
```bash
flyte deploy hello.py my_env
```
Arguments to the deploy command are provided right after the `deploy` command and before the file name.
To deploy all environments in a file, use the `--all` flag:
```bash
flyte deploy --all hello.py
```
To recursively deploy all environments in a directory and its subdirectories, use the `--recursive` flag:
```bash
flyte deploy --recursive ./src
```
You can combine `--all` and `--recursive` to deploy everything:
```bash
flyte deploy --all --recursive ./src
```
You can provide image mappings with `--image` flag. This allows you to specify
the image URI for the task environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
```bash
flyte deploy --image my_image=ghcr.io/myorg/my-image:v1.0 hello.py my_env
```
If the image name is not provided, it is regarded as a default image and will
be used when no image is specified in TaskEnvironment:
```bash
flyte deploy --image ghcr.io/myorg/default-image:latest hello.py my_env
```
You can specify multiple image arguments:
```bash
flyte deploy --image ghcr.io/org/default:latest --image gpu=ghcr.io/org/gpu:v2.0 hello.py my_env
```
To deploy a specific version, use the `--version` flag:
```bash
flyte deploy --version v1.0.0 hello.py my_env
```
To preview what would be deployed without actually deploying, use the `--dry-run` flag:
```bash
flyte deploy --dry-run hello.py my_env
```
You can specify the `--config` flag to point to a specific Flyte cluster:
```bash
flyte --config my-config.yaml deploy hello.py my_env
```
You can override the default configured project and domain:
```bash
flyte deploy --project my-project --domain development hello.py my_env
```
If loading some files fails during recursive deployment, you can use the `--ignore-load-errors` flag
to continue deploying the environments that loaded successfully:
```bash
flyte deploy --recursive --ignore-load-errors ./src
```
Other arguments to the deploy command are listed below.
To see the environments available in a file, use `--help` after the file name:
```bash
flyte deploy hello.py --help
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--version` | `text` | `Sentinel.UNSET` | Version of the environment to deploy |
| `--dry-run`, `--dryrun` | `boolean` | `False` | Dry run. Do not actually call the backend service. |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when running the task |
| `--root-dir` | `text` | `Sentinel.UNSET` | Override the root source directory, helpful when working with monorepos. |
| `--recursive`, `-r` | `boolean` | `False` | Recursively deploy all environments in the current directory |
| `--all` | `boolean` | `False` | Deploy all environments in the current directory, ignoring the file name |
| `--ignore-load-errors`, `-i` | `boolean` | `False` | Ignore errors when loading environments, especially when using --recursive or --all. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--image` | `text` | `Sentinel.UNSET` | Image to be used in the run. Format: imagename=imageuri. Can be specified multiple times. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte gen
**`flyte gen COMMAND [ARGS]...`**
Generate documentation.
#### flyte gen docs
**`flyte gen docs [OPTIONS]`**
Generate documentation.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--type` | `text` | `Sentinel.UNSET` | Type of documentation (valid: markdown) |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte get
**`flyte get COMMAND [ARGS]...`**
Retrieve resources from a Flyte deployment.
You can get information about projects, runs, tasks, actions, secrets, logs and input/output values.
Each command supports optional parameters to filter or specify the resource you want to retrieve.
Using a `get` subcommand without any arguments will retrieve a list of available resources to get.
For example:
* `get project` (without specifying a project) will list all projects.
* `get project my_project` will return the details of the project named `my_project`.
In some cases, a partially specified command acts as a filter and lists the available further parameters.
For example:
* `get action my_run` will return all actions for the run named `my_run`.
* `get action my_run my_action` will return the details of the action named `my_action` for the run `my_run`.
#### flyte get action
**`flyte get action [OPTIONS] RUN_NAME [ACTION_NAME]`**
Get all actions for a run or details for a specific action.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--in-phase` | `choice` | `Sentinel.UNSET` | Filter actions by their phase. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get api-key
> **Note:** This command is provided by the `flyteplugins.union` plugin. See the plugin documentation for installation instructions.
**`flyte get api-key [OPTIONS] [CLIENT_ID]`**
Get or list API keys.
If `CLIENT_ID` is provided, gets a specific API key.
Otherwise, lists all API keys.
Examples:
```bash
# List all API keys
$ flyte get api-key

# List with a limit
$ flyte get api-key --limit 10

# Get a specific API key
$ flyte get api-key my-client-id
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Maximum number of keys to list |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get app
**`flyte get app [OPTIONS] [NAME]`**
Get a list of all apps, or details of a specific app by name.
Apps are long-running services deployed on the Flyte platform.
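For example:
```bash
# List all apps
$ flyte get app

# Get details of a specific app (hypothetical name)
$ flyte get app my_app
```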
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of apps to fetch when listing. |
| `--only-mine` | `boolean` | `False` | Show only apps created by the current user (you). |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get config
**`flyte get config`**
Shows the automatically detected configuration to connect with the remote backend.
The configuration will include the endpoint, organization, and other settings that are used by the CLI.
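For example:
```bash
$ flyte get config
```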
#### flyte get io
**`flyte get io [OPTIONS] RUN_NAME [ACTION_NAME]`**
Get the inputs and outputs of a run or action.
If only the run name is provided, it will show the inputs and outputs of the root action of that run.
If an action name is provided, it will show the inputs and outputs for that action.
If `--inputs-only` or `--outputs-only` is specified, it will only show the inputs or outputs respectively.
Examples:
```bash
$ flyte get io my_run
```
```bash
$ flyte get io my_run my_action
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--inputs-only`, `-i` | `boolean` | `False` | Show only inputs |
| `--outputs-only`, `-o` | `boolean` | `False` | Show only outputs |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get logs
**`flyte get logs [OPTIONS] RUN_NAME [ACTION_NAME]`**
Stream logs for the provided run or action.
If only the run is provided, only the logs for the parent action will be streamed:
```bash
$ flyte get logs my_run
```
If you want to see the logs for a specific action, you can provide the action name as well:
```bash
$ flyte get logs my_run my_action
```
By default, logs are shown in raw format and scroll the terminal.
To tail only the last `--lines` lines in an auto-scrolling box, use the `--pretty` flag:
```bash
$ flyte get logs my_run my_action --pretty --lines 50
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--lines`, `-l` | `integer` | `30` | Number of lines to show, only useful for --pretty |
| `--show-ts` | `boolean` | `False` | Show timestamps |
| `--pretty` | `boolean` | `False` | Show logs in an auto-scrolling box, where number of lines is limited to `--lines` |
| `--attempt`, `-a` | `integer` | | Attempt number to show logs for, defaults to the latest attempt. |
| `--filter-system` | `boolean` | `False` | Filter all system logs from the output. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get project
**`flyte get project [NAME]`**
Get a list of all projects, or details of a specific project by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get run
**`flyte get run [OPTIONS] [NAME]`**
Get a list of all runs, or details of a specific run by name.
The run details include information about the run and its status, but only the root action is shown.
To see all actions for a run, use `get action <run_name>`.
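For example:
```bash
# List recent runs
$ flyte get run

# Get details of a specific run (hypothetical name)
$ flyte get run my_run
```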
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of runs to fetch when listing. |
| `--in-phase` | `choice` | `Sentinel.UNSET` | Filter runs by their status. |
| `--only-mine` | `boolean` | `False` | Show only runs created by the current user (you). |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get secret
**`flyte get secret [OPTIONS] [NAME]`**
Get a list of all secrets, or details of a specific secret by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get task
**`flyte get task [OPTIONS] [NAME] [VERSION]`**
Retrieve a list of all tasks, or details of a specific task by name and version.
Currently, both `name` and `version` are required to get a specific task.
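For example (the task name and version shown are hypothetical):
```bash
# List all tasks
$ flyte get task

# Get a specific task by name and version
$ flyte get task my_env.my_task abc123
```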
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of tasks to fetch. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get trigger
**`flyte get trigger [OPTIONS] [TASK_NAME] [NAME]`**
Get a list of all triggers, or details of a specific trigger by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of triggers to fetch. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte prefetch
**`flyte prefetch COMMAND [ARGS]...`**
Prefetch artifacts from remote registries.
These commands help you download and prefetch artifacts like HuggingFace models
to your Flyte storage for faster access during task execution.
#### flyte prefetch hf-model
**`flyte prefetch hf-model [OPTIONS] REPO`**
Prefetch a HuggingFace model to Flyte storage.
Downloads a model from the HuggingFace Hub and prefetches it to your configured
Flyte storage backend. This is useful for:
- Pre-fetching large models before running inference tasks
- Sharding models for tensor-parallel inference
- Avoiding repeated downloads during development
**Basic Usage:**
```bash
$ flyte prefetch hf-model meta-llama/Llama-2-7b-hf --hf-token-key HF_TOKEN
```
**With Sharding:**
Create a shard config file (shard_config.yaml):
```yaml
engine: vllm
args:
tensor_parallel_size: 8
dtype: auto
trust_remote_code: true
```
Then run:
```bash
$ flyte prefetch hf-model meta-llama/Llama-2-70b-hf \
--shard-config shard_config.yaml \
--accelerator A100:8 \
--hf-token-key HF_TOKEN
```
**Wait for Completion:**
```bash
$ flyte prefetch hf-model meta-llama/Llama-2-7b-hf --wait
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--raw-data-path` | `text` | | Object store path to store the model. If not provided, the model will be stored using the default path generated by Flyte storage layer. |
| `--artifact-name` | `text` | | Artifact name to use for the stored model. Must only contain alphanumeric characters, underscores, and hyphens. If not provided, the repo name will be used (replacing '.' with '-'). |
| `--architecture` | `text` | `Sentinel.UNSET` | Model architecture, as given in HuggingFace config.json. |
| `--task` | `text` | `auto` | Model task, e.g., 'generate', 'classify', 'embed', 'score', etc. Refer to vLLM docs. 'auto' will try to discover this automatically. |
| `--modality` | `text` | `('text',)` | Modalities supported by the model, e.g., 'text', 'image', 'audio', 'video'. Can be specified multiple times. |
| `--format` | `text` | `Sentinel.UNSET` | Model serialization format, e.g., safetensors, onnx, torchscript, joblib, etc. |
| `--model-type` | `text` | `Sentinel.UNSET` | Model type, e.g., 'transformer', 'xgboost', 'custom', etc. For HuggingFace models, this is auto-determined from config.json['model_type']. |
| `--short-description` | `text` | `Sentinel.UNSET` | Short description of the model. |
| `--force` | `integer` | `0` | Force store of the model. Increment value (--force=1, --force=2, ...) to force a new store. |
| `--wait` | `boolean` | `False` | Wait for the model to be stored before returning. |
| `--hf-token-key` | `text` | `HF_TOKEN` | Name of the Flyte secret containing your HuggingFace token. Note: This is not the HuggingFace token itself, but the name of the secret in the Flyte secret store. |
| `--cpu` | `text` | `2` | CPU request for the prefetch task (e.g., '2', '4', '2,4' for 2-4 CPUs). |
| `--mem` | `text` | `8Gi` | Memory request for the prefetch task (e.g., '16Gi', '64Gi', '16Gi,64Gi' for 16-64GB). |
| `--gpu` | `choice` | | The gpu to use for downloading and (optionally) sharding the model. Format: '{type}:{quantity}' (e.g., 'A100:8', 'L4:1'). |
| `--disk` | `text` | `50Gi` | Disk storage request for the prefetch task (e.g., '100Gi', '500Gi'). |
| `--shm` | `text` | | Shared memory request for the prefetch task (e.g., '100Gi', 'auto'). |
| `--shard-config` | `path` | `Sentinel.UNSET` | Path to a YAML file containing sharding configuration. The file should have 'engine' (currently only 'vllm') and 'args' keys. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte run
**`flyte run [OPTIONS] COMMAND [ARGS]...`**
Run a task from a python file or deployed task.
Example usage:
```bash
flyte run hello.py my_task --arg1 value1 --arg2 value2
```
Arguments to the run command are provided right after the `run` command and before the file name.
Arguments for the task itself are provided after the task name.
To run a task locally, use the `--local` flag. This will run the task in the local environment instead of the remote
Flyte environment:
```bash
flyte run --local hello.py my_task --arg1 value1 --arg2 value2
```
You can provide image mappings with `--image` flag. This allows you to specify
the image URI for the task environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
```bash
flyte run --image my_image=ghcr.io/myorg/my-image:v1.0 hello.py my_task
```
If the image name is not provided, it is regarded as a default image and will
be used when no image is specified in TaskEnvironment:
```bash
flyte run --image ghcr.io/myorg/default-image:latest hello.py my_task
```
You can specify multiple image arguments:
```bash
flyte run --image ghcr.io/org/default:latest --image gpu=ghcr.io/org/gpu:v2.0 hello.py my_task
```
To run tasks that you've already deployed to Flyte, use the deployed-task command:
```bash
flyte run deployed-task my_env.my_task --arg1 value1 --arg2 value2
```
To run a specific version of a deployed task, use the `env.task:version` syntax:
```bash
flyte run deployed-task my_env.my_task:xyz123 --arg1 value1 --arg2 value2
```
You can specify the `--config` flag to point to a specific Flyte cluster:
```bash
flyte run --config my-config.yaml deployed-task ...
```
You can override the default configured project and domain:
```bash
flyte run --project my-project --domain development hello.py my_task
```
You can discover what deployed tasks are available by running:
```bash
flyte run deployed-task
```
Other arguments to the run command are listed below.
Arguments for the task itself are provided after the task name and can be retrieved using `--help`. For example:
```bash
flyte run hello.py my_task --help
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--local` | `boolean` | `False` | Run the task locally |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when running the task |
| `--root-dir` | `text` | `Sentinel.UNSET` | Override the root source directory, helpful when working with monorepos. |
| `--raw-data-path` | `text` | `Sentinel.UNSET` | Override the output prefix used to store offloaded data types. e.g. s3://bucket/ |
| `--service-account` | `text` | `Sentinel.UNSET` | Kubernetes service account. If not provided, the configured default will be used |
| `--name` | `text` | `Sentinel.UNSET` | Name of the run. If not provided, a random name will be generated. |
| `--follow`, `-f` | `boolean` | `False` | Wait and watch logs for the parent action. If not provided, the CLI will exit after successfully launching a remote execution with a link to the UI. |
| `--image` | `text` | `Sentinel.UNSET` | Image to be used in the run. Format: imagename=imageuri. Can be specified multiple times. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte run deployed-task
**`flyte run deployed-task [OPTIONS] COMMAND [ARGS]...`**
Run a reference task from the Flyte backend.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte serve
**`flyte serve [OPTIONS] COMMAND [ARGS]...`**
Serve an app from a Python file using flyte.serve().
This command allows you to serve apps defined with `flyte.app.AppEnvironment`
in your Python files. The serve command will deploy the app to the Flyte backend
and start it, making it accessible via a URL.
Example usage:
```bash
flyte serve examples/apps/basic_app.py app_env
```
Arguments to the serve command are provided right after the `serve` command and before the file name.
To follow the logs of the served app, use the `--follow` flag:
```bash
flyte serve --follow examples/apps/basic_app.py app_env
```
Note: Log streaming is not yet fully implemented and will be added in a future release.
You can provide image mappings with `--image` flag. This allows you to specify
the image URI for the app environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
```bash
flyte serve --image my_image=ghcr.io/myorg/my-image:v1.0 examples/apps/basic_app.py app_env
```
If the image name is not provided, it is regarded as a default image and will
be used when no image is specified in AppEnvironment:
```bash
flyte serve --image ghcr.io/myorg/default-image:latest examples/apps/basic_app.py app_env
```
You can specify multiple image arguments:
```bash
flyte serve --image ghcr.io/org/default:latest --image gpu=ghcr.io/org/gpu:v2.0 examples/apps/basic_app.py app_env
```
You can specify the `--config` flag to point to a specific Flyte cluster:
```bash
flyte serve --config my-config.yaml examples/apps/basic_app.py app_env
```
You can override the default configured project and domain:
```bash
flyte serve --project my-project --domain development examples/apps/basic_app.py app_env
```
Other arguments to the serve command are listed below.
Note: This pattern is primarily useful for serving apps defined in tasks.
Serving deployed apps is not currently supported through this CLI command.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when serving the app |
| `--root-dir` | `text` | `Sentinel.UNSET` | Override the root source directory, helpful when working with monorepos. |
| `--service-account` | `text` | `Sentinel.UNSET` | Kubernetes service account. If not provided, the configured default will be used |
| `--name` | `text` | `Sentinel.UNSET` | Name of the app deployment. If not provided, the app environment name will be used. |
| `--follow`, `-f` | `boolean` | `False` | Wait and watch logs for the app. If not provided, the CLI will exit after successfully deploying the app with a link to the UI. |
| `--image` | `text` | `Sentinel.UNSET` | Image to be used in the serve. Format: imagename=imageuri. Can be specified multiple times. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--env-var`, `-e` | `text` | `Sentinel.UNSET` | Environment variable to set in the app. Format: KEY=VALUE. Can be specified multiple times. Example: --env-var LOG_LEVEL=DEBUG --env-var DATABASE_URL=postgresql://... |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte update
**`flyte update COMMAND [ARGS]...`**
Update various flyte entities.
#### flyte update app
**`flyte update app [OPTIONS] NAME`**
Update an app by starting or stopping it.
Example usage:
```bash
flyte update app <app-name> --activate | --deactivate [--wait] [--project <project>] [--domain <domain>]
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--activate` / `--deactivate` | `boolean` | | Activate or deactivate app. |
| `--wait` | `boolean` | `False` | Wait for the app to reach the desired state. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte update trigger
**`flyte update trigger [OPTIONS] NAME TASK_NAME`**
Update a trigger.
Example usage:
```bash
flyte update trigger <name> <task_name> --activate | --deactivate [--project <project> --domain <domain>]
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--activate` / `--deactivate` | `boolean` | `Sentinel.UNSET` | Activate or deactivate the trigger. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte whoami
**`flyte whoami`**
Display the current user information.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-context ===
# LLM context document
The following document provides LLM context for authoring and running Flyte/Union workflows.
It can serve as a reference for LLM-based AI assistants to understand how to properly write, configure, and execute Flyte/Union workflows.
* **Full documentation content**: The entire documentation (this site) for Flyte version 2.0 in a single text file.
* π₯ [llms-full.txt](/_static/public/llms-full.txt)
You can add it to the context window of your LLM-based AI assistant to help it better understand Flyte/Union development.
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk ===
# Flyte SDK
These are the docs for Flyte SDK version 2.0.
Flyte is the core Python SDK for the Union and Flyte platforms.
## Subpages
- **Flyte SDK > Classes**
- **Flyte SDK > Packages**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/classes ===
# Classes
| Class | Description |
|-|-|
| [`flyte.Cache`](../packages/flyte/cache) |Cache configuration for a task. |
| [`flyte.Cron`](../packages/flyte/cron) |This class defines a Cron automation that can be associated with a Trigger in Flyte. |
| [`flyte.Device`](../packages/flyte/device) |Represents a device type, its quantity and partition if applicable. |
| [`flyte.Environment`](../packages/flyte/environment) | |
| [`flyte.FixedRate`](../packages/flyte/fixedrate) |This class defines a FixedRate automation that can be associated with a Trigger in Flyte. |
| [`flyte.Image`](../packages/flyte/image) |This is a representation of Container Images, which can be used to create layered images programmatically. |
| [`flyte.PodTemplate`](../packages/flyte/podtemplate) |Custom PodTemplate specification for a Task. |
| [`flyte.Resources`](../packages/flyte/resources) |Resources such as CPU, Memory, and GPU that can be allocated to a task. |
| [`flyte.RetryStrategy`](../packages/flyte/retrystrategy) |Retry strategy for the task or task environment. |
| [`flyte.ReusePolicy`](../packages/flyte/reusepolicy) |ReusePolicy can be used to configure a task to reuse the environment. |
| [`flyte.Secret`](../packages/flyte/secret) |Secrets are used to inject sensitive information into tasks or image build context. |
| [`flyte.TaskEnvironment`](../packages/flyte/taskenvironment) |Environment class to define a new environment for a set of tasks. |
| [`flyte.Timeout`](../packages/flyte/timeout) |Timeout class to define a timeout for a task. |
| [`flyte.Trigger`](../packages/flyte/trigger) |This class defines specification of a Trigger, that can be associated with any Flyte V2 task. |
| [`flyte.app.AppEndpoint`](../packages/flyte.app/appendpoint) |Embed an upstream app's endpoint as an app input. |
| [`flyte.app.AppEnvironment`](../packages/flyte.app/appenvironment) | |
| [`flyte.app.Domain`](../packages/flyte.app/domain) |Subdomain to use for the domain. |
| [`flyte.app.Input`](../packages/flyte.app/input) |Input for application. |
| [`flyte.app.Link`](../packages/flyte.app/link) |Custom links to add to the app. |
| [`flyte.app.Port`](../packages/flyte.app/port) | |
| [`flyte.app.RunOutput`](../packages/flyte.app/runoutput) |Use a run's output for app inputs. |
| [`flyte.app.Scaling`](../packages/flyte.app/scaling) | |
| [`flyte.app.extras.FastAPIAppEnvironment`](../packages/flyte.app.extras/fastapiappenvironment) | |
| [`flyte.config.Config`](../packages/flyte.config/config) |This is the parent configuration object and holds all the underlying configuration object types. |
| [`flyte.errors.ActionNotFoundError`](../packages/flyte.errors/actionnotfounderror) |This error is raised when the user tries to access an action that does not exist. |
| [`flyte.errors.BaseRuntimeError`](../packages/flyte.errors/baseruntimeerror) |Base class for all Union runtime errors. |
| [`flyte.errors.CustomError`](../packages/flyte.errors/customerror) |This error is raised when the user raises a custom error. |
| [`flyte.errors.DeploymentError`](../packages/flyte.errors/deploymenterror) |This error is raised when the deployment of a task fails, or some preconditions for deployment are not met. |
| [`flyte.errors.ImageBuildError`](../packages/flyte.errors/imagebuilderror) |This error is raised when the image build fails. |
| [`flyte.errors.ImagePullBackOffError`](../packages/flyte.errors/imagepullbackofferror) |This error is raised when the image cannot be pulled. |
| [`flyte.errors.InitializationError`](../packages/flyte.errors/initializationerror) |This error is raised when the Union system is accessed without being initialized. |
| [`flyte.errors.InlineIOMaxBytesBreached`](../packages/flyte.errors/inlineiomaxbytesbreached) |This error is raised when the inline IO max bytes limit is breached. |
| [`flyte.errors.InvalidImageNameError`](../packages/flyte.errors/invalidimagenameerror) |This error is raised when the image name is invalid. |
| [`flyte.errors.LogsNotYetAvailableError`](../packages/flyte.errors/logsnotyetavailableerror) |This error is raised when the logs are not yet available for a task. |
| [`flyte.errors.ModuleLoadError`](../packages/flyte.errors/moduleloaderror) |This error is raised when the module cannot be loaded, either because it does not exist or because loading failed. |
| [`flyte.errors.NotInTaskContextError`](../packages/flyte.errors/notintaskcontexterror) |This error is raised when the user tries to access the task context outside of a task. |
| [`flyte.errors.OOMError`](../packages/flyte.errors/oomerror) |This error is raised when the underlying task execution fails because of an out-of-memory error. |
| [`flyte.errors.OnlyAsyncIOSupportedError`](../packages/flyte.errors/onlyasynciosupportederror) |This error is raised when the user tries to use sync IO in an async task. |
| [`flyte.errors.PrimaryContainerNotFoundError`](../packages/flyte.errors/primarycontainernotfounderror) |This error is raised when the primary container is not found. |
| [`flyte.errors.ReferenceTaskError`](../packages/flyte.errors/referencetaskerror) |This error is raised when the user tries to access a task that does not exist. |
| [`flyte.errors.RetriesExhaustedError`](../packages/flyte.errors/retriesexhaustederror) |This error is raised when the underlying task execution fails after all retries have been exhausted. |
| [`flyte.errors.RunAbortedError`](../packages/flyte.errors/runabortederror) |This error is raised when the run is aborted by the user. |
| [`flyte.errors.RuntimeDataValidationError`](../packages/flyte.errors/runtimedatavalidationerror) |This error is raised when the user tries to access a resource that does not exist or is invalid. |
| [`flyte.errors.RuntimeSystemError`](../packages/flyte.errors/runtimesystemerror) |This error is raised when the underlying task execution fails because of a system error. |
| [`flyte.errors.RuntimeUnknownError`](../packages/flyte.errors/runtimeunknownerror) |This error is raised when the underlying task execution fails because of an unknown error. |
| [`flyte.errors.RuntimeUserError`](../packages/flyte.errors/runtimeusererror) |This error is raised when the underlying task execution fails because of an error in the user's code. |
| [`flyte.errors.SlowDownError`](../packages/flyte.errors/slowdownerror) |This error is raised when the user tries to access a resource that does not exist or is invalid. |
| [`flyte.errors.TaskInterruptedError`](../packages/flyte.errors/taskinterruptederror) |This error is raised when the underlying task execution is interrupted. |
| [`flyte.errors.TaskTimeoutError`](../packages/flyte.errors/tasktimeouterror) |This error is raised when the underlying task execution runs for longer than the specified timeout. |
| [`flyte.errors.UnionRpcError`](../packages/flyte.errors/unionrpcerror) |This error is raised when communication with the Union server fails. |
| [`flyte.extend.AsyncFunctionTaskTemplate`](../packages/flyte.extend/asyncfunctiontasktemplate) |A task template that wraps an asynchronous function. |
| [`flyte.extend.ImageBuildEngine`](../packages/flyte.extend/imagebuildengine) |ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. |
| [`flyte.extend.TaskTemplate`](../packages/flyte.extend/tasktemplate) |Task template is a template for a task that can be executed. |
| [`flyte.extras.ContainerTask`](../packages/flyte.extras/containertask) |This is an intermediate class that represents Flyte Tasks that run a container at execution time. |
| [`flyte.git.GitStatus`](../packages/flyte.git/gitstatus) |A class representing the status of a git repository. |
| [`flyte.io.DataFrame`](../packages/flyte.io/dataframe) |This is the user facing DataFrame class. |
| [`flyte.io.DataFrameDecoder`](../packages/flyte.io/dataframedecoder) |Helper class that provides a standard way to create an ABC using inheritance. |
| [`flyte.io.DataFrameEncoder`](../packages/flyte.io/dataframeencoder) |Helper class that provides a standard way to create an ABC using inheritance. |
| [`flyte.io.DataFrameTransformerEngine`](../packages/flyte.io/dataframetransformerengine) |Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. |
| [`flyte.io.Dir`](../packages/flyte.io/dir) |A generic directory class representing a directory with files of a specified format. |
| [`flyte.io.File`](../packages/flyte.io/file) |A generic file class representing a file with a specified format. |
| [`flyte.models.ActionID`](../packages/flyte.models/actionid) |A class representing the ID of an Action, nested within a Run. |
| [`flyte.models.ActionPhase`](../packages/flyte.models/actionphase) |Represents the execution phase of a Flyte action (run). |
| [`flyte.models.Checkpoints`](../packages/flyte.models/checkpoints) |A class representing the checkpoints for a task. |
| [`flyte.models.CodeBundle`](../packages/flyte.models/codebundle) |A class representing a code bundle for a task. |
| [`flyte.models.GroupData`](../packages/flyte.models/groupdata) | |
| [`flyte.models.NativeInterface`](../packages/flyte.models/nativeinterface) |A class representing the native interface for a task. |
| [`flyte.models.PathRewrite`](../packages/flyte.models/pathrewrite) |Configuration for rewriting paths during input loading. |
| [`flyte.models.RawDataPath`](../packages/flyte.models/rawdatapath) |A class representing the raw data path for a task. |
| [`flyte.models.SerializationContext`](../packages/flyte.models/serializationcontext) |This object holds serialization time contextual information that can be used when serializing the task. |
| [`flyte.models.TaskContext`](../packages/flyte.models/taskcontext) |A context class to hold the current task executions context. |
| [`flyte.prefetch.HuggingFaceModelInfo`](../packages/flyte.prefetch/huggingfacemodelinfo) |Information about a HuggingFace model to store. |
| [`flyte.prefetch.ShardConfig`](../packages/flyte.prefetch/shardconfig) |Configuration for model sharding. |
| [`flyte.prefetch.StoredModelInfo`](../packages/flyte.prefetch/storedmodelinfo) |Information about a stored model. |
| [`flyte.prefetch.VLLMShardArgs`](../packages/flyte.prefetch/vllmshardargs) |Arguments for sharding a model using vLLM. |
| [`flyte.remote.Action`](../packages/flyte.remote/action) |A class representing an action. |
| [`flyte.remote.ActionDetails`](../packages/flyte.remote/actiondetails) |A class representing an action. |
| [`flyte.remote.ActionInputs`](../packages/flyte.remote/actioninputs) |A class representing the inputs of an action. |
| [`flyte.remote.ActionOutputs`](../packages/flyte.remote/actionoutputs) |A class representing the outputs of an action. |
| [`flyte.remote.App`](../packages/flyte.remote/app) |A mixin class that provides a method to convert an object to a JSON-serializable dictionary. |
| [`flyte.remote.Project`](../packages/flyte.remote/project) |A class representing a project in the Union API. |
| [`flyte.remote.Run`](../packages/flyte.remote/run) |A class representing a run of a task. |
| [`flyte.remote.RunDetails`](../packages/flyte.remote/rundetails) |A class representing a run of a task. |
| [`flyte.remote.Secret`](../packages/flyte.remote/secret) | |
| [`flyte.remote.Task`](../packages/flyte.remote/task) | |
| [`flyte.remote.TaskDetails`](../packages/flyte.remote/taskdetails) | |
| [`flyte.remote.Trigger`](../packages/flyte.remote/trigger) | |
| [`flyte.remote.User`](../packages/flyte.remote/user) | |
| [`flyte.report.Report`](../packages/flyte.report/report) | |
| [`flyte.storage.ABFS`](../packages/flyte.storage/abfs) |Any Azure Blob Storage specific configuration. |
| [`flyte.storage.GCS`](../packages/flyte.storage/gcs) |Any GCS specific configuration. |
| [`flyte.storage.S3`](../packages/flyte.storage/s3) |S3 specific configuration. |
| [`flyte.storage.Storage`](../packages/flyte.storage/storage) |Data storage configuration that applies across any provider. |
| [`flyte.syncify.Syncify`](../packages/flyte.syncify/syncify) |A decorator to convert asynchronous functions or methods into synchronous ones. |
| [`flyte.types.FlytePickle`](../packages/flyte.types/flytepickle) |This type is only used by flytekit internally. |
| [`flyte.types.TypeEngine`](../packages/flyte.types/typeengine) |Core Extensible TypeEngine of Flytekit. |
| [`flyte.types.TypeTransformer`](../packages/flyte.types/typetransformer) |Base transformer type that should be implemented for every python native type that can be handled by flytekit. |
| [`flyte.types.TypeTransformerFailedError`](../packages/flyte.types/typetransformerfailederror) |Inappropriate argument type. |
# Protocols
| Protocol | Description |
|-|-|
| [`flyte.CachePolicy`](../packages/flyte/cachepolicy) |Base class for protocol classes. |
| [`flyte.types.Renderable`](../packages/flyte.types/renderable) |Base class for protocol classes. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages ===
# Packages
| Package | Description |
|-|-|
| **Flyte SDK > Packages > flyte** | Flyte SDK for authoring compound AI applications, services and workflows. |
| **Flyte SDK > Packages > flyte.app** | |
| **Flyte SDK > Packages > flyte.app.extras** | |
| **Flyte SDK > Packages > flyte.config** | |
| **Flyte SDK > Packages > flyte.errors** | Exceptions raised by Union. |
| **Flyte SDK > Packages > flyte.extend** | |
| **Flyte SDK > Packages > flyte.extras** | |
| **Flyte SDK > Packages > flyte.git** | |
| **Flyte SDK > Packages > flyte.io** | IO data types. |
| **Flyte SDK > Packages > flyte.models** | |
| **Flyte SDK > Packages > flyte.prefetch** | Prefetch utilities for Flyte. |
| **Flyte SDK > Packages > flyte.remote** | Remote Entities that are accessible from the Union Server once deployed or created. |
| **Flyte SDK > Packages > flyte.report** | |
| **Flyte SDK > Packages > flyte.storage** | |
| **Flyte SDK > Packages > flyte.syncify** | Syncify Module. |
| **Flyte SDK > Packages > flyte.types** | Flyte Type System. |
## Subpages
- **Flyte SDK > Packages > flyte**
- **Flyte SDK > Packages > flyte.app**
- **Flyte SDK > Packages > flyte.app.extras**
- **Flyte SDK > Packages > flyte.config**
- **Flyte SDK > Packages > flyte.errors**
- **Flyte SDK > Packages > flyte.extend**
- **Flyte SDK > Packages > flyte.extras**
- **Flyte SDK > Packages > flyte.git**
- **Flyte SDK > Packages > flyte.io**
- **Flyte SDK > Packages > flyte.models**
- **Flyte SDK > Packages > flyte.prefetch**
- **Flyte SDK > Packages > flyte.remote**
- **Flyte SDK > Packages > flyte.report**
- **Flyte SDK > Packages > flyte.storage**
- **Flyte SDK > Packages > flyte.syncify**
- **Flyte SDK > Packages > flyte.types**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte ===
# flyte
Flyte SDK for authoring compound AI applications, services and workflows.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte > `Cache`** | Cache configuration for a task. |
| **Flyte SDK > Packages > flyte > `Cron`** | This class defines a Cron automation that can be associated with a Trigger in Flyte. |
| **Flyte SDK > Packages > flyte > `Device`** | Represents a device type, its quantity and partition if applicable. |
| **Flyte SDK > Packages > flyte > `Environment`** | |
| **Flyte SDK > Packages > flyte > `FixedRate`** | This class defines a FixedRate automation that can be associated with a Trigger in Flyte. |
| **Flyte SDK > Packages > flyte > `Image`** | This is a representation of Container Images, which can be used to create layered images programmatically. |
| **Flyte SDK > Packages > flyte > `PodTemplate`** | Custom PodTemplate specification for a Task. |
| **Flyte SDK > Packages > flyte > `Resources`** | Resources such as CPU, Memory, and GPU that can be allocated to a task. |
| **Flyte SDK > Packages > flyte > `RetryStrategy`** | Retry strategy for the task or task environment. |
| **Flyte SDK > Packages > flyte > `ReusePolicy`** | ReusePolicy can be used to configure a task to reuse the environment. |
| **Flyte SDK > Packages > flyte > `Secret`** | Secrets are used to inject sensitive information into tasks or image build context. |
| **Flyte SDK > Packages > flyte > `TaskEnvironment`** | Environment class to define a new environment for a set of tasks. |
| **Flyte SDK > Packages > flyte > `Timeout`** | Timeout class to define a timeout for a task. |
| **Flyte SDK > Packages > flyte > `Trigger`** | This class defines specification of a Trigger, that can be associated with any Flyte V2 task. |
### Protocols
| Protocol | Description |
|-|-|
| **Flyte SDK > Packages > flyte > `CachePolicy`** | Base class for protocol classes. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > `AMD_GPU()`** | Create an AMD GPU device instance. |
| **Flyte SDK > Packages > flyte > Methods > GPU()** | Create a GPU device instance. |
| **Flyte SDK > Packages > flyte > `HABANA_GAUDI()`** | Create a Habana Gaudi device instance. |
| **Flyte SDK > Packages > flyte > Methods > Neuron()** | Create a Neuron device instance. |
| **Flyte SDK > Packages > flyte > Methods > TPU()** | Create a TPU device instance. |
| **Flyte SDK > Packages > flyte > Methods > build()** | Build an image. |
| **Flyte SDK > Packages > flyte > `build_images()`** | Build the images for the given environments. |
| **Flyte SDK > Packages > flyte > Methods > ctx()** | Returns flyte.models.TaskContext if within a task context, else None. |
| **Flyte SDK > Packages > flyte > `current_domain()`** | Returns the current domain from Runtime environment (on the cluster) or from the initialized configuration. |
| **Flyte SDK > Packages > flyte > `custom_context()`** | Synchronous context manager to set input context for tasks spawned within this block. |
| **Flyte SDK > Packages > flyte > Methods > deploy()** | Deploy the given environment or list of environments. |
| **Flyte SDK > Packages > flyte > `get_custom_context()`** | Get the current input context. |
| **Flyte SDK > Packages > flyte > Methods > group()** | Create a new group with the given name. |
| **Flyte SDK > Packages > flyte > Methods > init()** | Initialize the Flyte system with the given configuration. |
| **Flyte SDK > Packages > flyte > `init_from_api_key()`** | Initialize the Flyte system using an API key for authentication. |
| **Flyte SDK > Packages > flyte > `init_from_config()`** | Initialize the Flyte system using a configuration file or Config object. |
| **Flyte SDK > Packages > flyte > `init_in_cluster()`** | |
| **Flyte SDK > Packages > flyte > Methods > map()** | Map a function over the provided arguments with concurrent execution. |
| **Flyte SDK > Packages > flyte > Methods > run()** | Run a task with the given parameters. |
| **Flyte SDK > Packages > flyte > Methods > serve()** | Serve a Flyte app using an AppEnvironment. |
| **Flyte SDK > Packages > flyte > trace()** | A decorator that traces function execution with timing information. |
| **Flyte SDK > Packages > flyte > version()** | Returns the version of the Flyte SDK. |
| **Flyte SDK > Packages > flyte > `with_runcontext()`** | Launch a new run with the given parameters as the context. |
| **Flyte SDK > Packages > flyte > `with_servecontext()`** | Create a serve context with custom configuration. |
### Variables
| Property | Type | Description |
|-|-|-|
| `TimeoutType` | `UnionType` | |
| `TriggerTime` | `_trigger_time` | |
| `__version__` | `str` | |
| `logger` | `Logger` | |
## Methods
#### AMD_GPU()
```python
def AMD_GPU(
device: typing.Literal['MI100', 'MI210', 'MI250', 'MI250X', 'MI300A', 'MI300X', 'MI325X', 'MI350X', 'MI355X'],
) -> flyte._resources.Device
```
Create an AMD GPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['MI100', 'MI210', 'MI250', 'MI250X', 'MI300A', 'MI300X', 'MI325X', 'MI350X', 'MI355X']` | Device type (e.g., "MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X", "MI355X"). :return: Device instance. |
#### GPU()
```python
def GPU(
device: typing.Literal['A10', 'A10G', 'A100', 'A100 80G', 'B200', 'H100', 'H200', 'L4', 'L40s', 'T4', 'V100', 'RTX PRO 6000', 'GB10'],
quantity: typing.Literal[1, 2, 3, 4, 5, 6, 7, 8],
partition: typing.Union[typing.Literal['1g.5gb', '2g.10gb', '3g.20gb', '4g.20gb', '7g.40gb'], typing.Literal['1g.10gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.18gb', '1g.35gb', '2g.35gb', '3g.71gb', '4g.71gb', '7g.141gb'], NoneType],
) -> flyte._resources.Device
```
Create a GPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['A10', 'A10G', 'A100', 'A100 80G', 'B200', 'H100', 'H200', 'L4', 'L40s', 'T4', 'V100', 'RTX PRO 6000', 'GB10']` | The type of GPU (e.g., "T4", "A100"). |
| `quantity` | `typing.Literal[1, 2, 3, 4, 5, 6, 7, 8]` | The number of GPUs of this type. |
| `partition` | `typing.Union[typing.Literal['1g.5gb', '2g.10gb', '3g.20gb', '4g.20gb', '7g.40gb'], typing.Literal['1g.10gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.18gb', '1g.35gb', '2g.35gb', '3g.71gb', '4g.71gb', '7g.141gb'], NoneType]` | The partition of the GPU (e.g., "1g.5gb", "2g.10gb" for gpus) or ("1x1", ... for tpus). :return: Device instance. |
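As a minimal sketch of how the returned device might be consumed (assuming, as the class directory above suggests, that `flyte.Resources` accepts a device via its `gpu` field and that `TaskEnvironment` takes a `resources` argument; this is not a confirmed signature):
```python
import flyte

# Hypothetical sketch: request one T4 GPU for every task in this environment.
env = flyte.TaskEnvironment(
    "gpu_env",
    resources=flyte.Resources(gpu=flyte.GPU("T4", 1)),
)
```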
#### HABANA_GAUDI()
```python
def HABANA_GAUDI(
device: typing.Literal['Gaudi1'],
) -> flyte._resources.Device
```
Create a Habana Gaudi device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['Gaudi1']` | Device type (e.g., "Gaudi1"). |
:return: Device instance.
#### Neuron()
```python
def Neuron(
device: typing.Literal['Inf1', 'Inf2', 'Trn1', 'Trn1n', 'Trn2', 'Trn2u'],
) -> flyte._resources.Device
```
Create a Neuron device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['Inf1', 'Inf2', 'Trn1', 'Trn1n', 'Trn2', 'Trn2u']` | Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u"). |
#### TPU()
```python
def TPU(
device: typing.Literal['V5P', 'V6E'],
partition: typing.Union[typing.Literal['2x2x1', '2x2x2', '2x4x4', '4x4x4', '4x4x8', '4x8x8', '8x8x8', '8x8x16', '8x16x16', '16x16x16', '16x16x24'], typing.Literal['1x1', '2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16'], NoneType],
)
```
Create a TPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['V5P', 'V6E']` | Device type (e.g., "V5P", "V6E"). |
| `partition` | `typing.Union[typing.Literal['2x2x1', '2x2x2', '2x4x4', '4x4x4', '4x4x8', '4x8x8', '8x8x8', '8x8x16', '8x16x16', '16x16x16', '16x16x24'], typing.Literal['1x1', '2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16'], NoneType]` | Partition of the TPU (e.g., "1x1", "2x2", ...). |
:return: Device instance.
#### build()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await build.aio()`.
```python
def build(
image: Image,
) -> str
```
Build an image. The existing async context will be used.
Example:
```python
import asyncio

import flyte

image = flyte.Image("example_image")

if __name__ == "__main__":
    asyncio.run(flyte.build.aio(image))
```
| Parameter | Type | Description |
|-|-|-|
| `image` | `Image` | The image(s) to build. |
:return: The image URI.
#### build_images()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await build_images.aio()`.
```python
def build_images(
envs: Environment,
) -> ImageCache
```
Build the images for the given environments.
| Parameter | Type | Description |
|-|-|-|
| `envs` | `Environment` | Environment to build images for. |
:return: ImageCache containing the built images.
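A minimal sketch of building an environment's images up front (it assumes a configuration file is available for `init_from_config()`):
```python
import flyte

env = flyte.TaskEnvironment(name="example_env")

if __name__ == "__main__":
    flyte.init_from_config()
    image_cache = flyte.build_images(env)  # blocks until the environment's images are built
```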
#### ctx()
```python
def ctx()
```
Returns `flyte.models.TaskContext` if called within a task context, else `None`.
Note: Only use this in task code, not at module level.
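A minimal sketch of inspecting the task context from inside a task (the available attributes are documented on `flyte.models.TaskContext`):
```python
import flyte

env = flyte.TaskEnvironment(name="ctx_example")

@env.task
def show_context():
    tctx = flyte.ctx()  # TaskContext inside a task; None otherwise
    print(tctx)
```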
#### current_domain()
```python
def current_domain()
```
Returns the current domain from the runtime environment (on the cluster) or from the initialized configuration.
This is safe to use during `deploy`, `run`, and within `task` code.
NOTE: This will not work if you deploy a task to one domain and then run it in another.
Raises `InitializationError` if the configuration is not initialized or the domain is not set.
:return: The current domain
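For example, a task can report which domain it is running in (a minimal sketch):
```python
import flyte

env = flyte.TaskEnvironment(name="domain_example")

@env.task
def report_domain() -> str:
    # Raises InitializationError if no domain is configured.
    return flyte.current_domain()
```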
#### custom_context()
```python
def custom_context(
context: str,
)
```
Synchronous context manager to set input context for tasks spawned within this block.
Example:
```python
import flyte
env = flyte.TaskEnvironment(name="...")
@env.task
def t1():
ctx = flyte.get_custom_context()
print(ctx)
@env.task
def main():
# context can be passed via a context manager
with flyte.custom_context(project="my-project"):
t1() # will have {'project': 'my-project'} as context
```
| Parameter | Type | Description |
|-|-|-|
| `context` | `str` | Key-value pairs to set as input context |
#### deploy()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await deploy.aio()`.
```python
def deploy(
envs: Environment,
dryrun: bool,
version: str | None,
interactive_mode: bool | None,
copy_style: CopyFiles,
) -> List[Deployment]
```
Deploy the given environment or list of environments.
| Parameter | Type | Description |
|-|-|-|
| `envs` | `Environment` | Environment or list of environments to deploy. |
| `dryrun` | `bool` | dryrun mode, if True, the deployment will not be applied to the control plane. |
| `version` | `str \| None` | Version of the deployment; if None, the version is computed from the code bundle. |
| `interactive_mode` | `bool \| None` | Optional; can be forced to True or False. If not provided, it is set based on the current environment: for example, Jupyter notebooks are considered interactive mode, while scripts are not. This is used to determine how the code bundle is created. |
| `copy_style` | `CopyFiles` | Copy style to use when running the task. |
:return: Deployment object containing the deployed environments and tasks.
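A minimal deployment sketch (it assumes a configuration file is available for `init_from_config()`):
```python
import flyte

env = flyte.TaskEnvironment(name="deploy_example")

@env.task
def hello() -> str:
    return "hello"

if __name__ == "__main__":
    flyte.init_from_config()
    deployments = flyte.deploy(env)  # returns a List[Deployment]
```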
#### get_custom_context()
```python
def get_custom_context()
```
Get the current input context. This can be used within a task to retrieve
context metadata that was passed to the action.
Context will automatically propagate to sub-actions.
Example:
```python
import flyte
env = flyte.TaskEnvironment(name="...")
@env.task
def t1():
# context can be retrieved with `get_custom_context`
ctx = flyte.get_custom_context()
print(ctx) # {'project': '...', 'entity': '...'}
```
:return: Dictionary of context key-value pairs
#### group()
```python
def group(
name: str,
)
```
Create a new group with the given name. The method is intended to be used as a context manager.
Example:
```python
@env.task
async def my_task():
    ...
    with flyte.group("my_group"):
        t1(x, y)  # tasks in this block will be grouped under "my_group"
    ...
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the group |
#### init()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await init.aio()`.
```python
def init(
org: str | None,
project: str | None,
domain: str | None,
root_dir: Path | None,
log_level: int | None,
log_format: LogFormat | None,
endpoint: str | None,
headless: bool,
insecure: bool,
insecure_skip_verify: bool,
ca_cert_file_path: str | None,
auth_type: AuthType,
command: List[str] | None,
proxy_command: List[str] | None,
api_key: str | None,
client_id: str | None,
client_credentials_secret: str | None,
auth_client_config: ClientConfig | None,
rpc_retries: int,
http_proxy_url: str | None,
storage: Storage | None,
batch_size: int,
image_builder: ImageBuildEngine.ImageBuilderType,
images: typing.Dict[str, str] | None,
source_config_path: Optional[Path],
sync_local_sys_paths: bool,
load_plugin_type_transformers: bool,
)
```
Initialize the Flyte system with the given configuration. This method should be called before any other Flyte
remote API methods are called. Thread-safe implementation.
| Parameter | Type | Description |
|-|-|-|
| `org` | `str \| None` | Optional organization override for the client. Should be set by auth instead. |
| `project` | `str \| None` | Optional project name (not used in this implementation) |
| `domain` | `str \| None` | Optional domain name (not used in this implementation) |
| `root_dir` | `Path \| None` | Optional root directory used to resolve paths to files such as config, to determine the root of the current project, and to determine all the code that needs to be copied to the remote location. Defaults to the editable install directory if the cwd is inside a Python editable install, else just the cwd. |
| `log_level` | `int \| None` | Optional logging level for the logger, default is set using the default initialization policies |
| `log_format` | `LogFormat \| None` | Optional logging format for the logger, default is "console" |
| `endpoint` | `str \| None` | Optional API endpoint URL |
| `headless` | `bool` | Whether to run in headless mode. |
| `insecure` | `bool` | Insecure flag for the client. |
| `insecure_skip_verify` | `bool` | Whether to skip SSL certificate verification. |
| `ca_cert_file_path` | `str \| None` | Optional root certificate to load and use when verifying the admin endpoint. |
| `auth_type` | `AuthType` | The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow) |
| `command` | `List[str] \| None` | This command is executed to return a token using an external process |
| `proxy_command` | `List[str] \| None` | This command is executed to return a token for proxy authorization using an external process |
| `api_key` | `str \| None` | Optional API key for authentication |
| `client_id` | `str \| None` | This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. |
| `client_credentials_secret` | `str \| None` | Used for service auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the password directly from the environment variable. Note that this is less secure! Please only use this if mounting the secret as a file is impossible |
| `auth_client_config` | `ClientConfig \| None` | Optional client configuration for authentication |
| `rpc_retries` | `int` | Optional number of times to retry platform calls. |
| `http_proxy_url` | `str \| None` | Optional HTTP proxy to use for OAuth requests. |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration, if needed for access (e.g., when using Minio). |
| `batch_size` | `int` | Optional batch size for operations that use listings; defaults to 1000, so a limit larger than `batch_size` is split into multiple requests. |
| `image_builder` | `ImageBuildEngine.ImageBuilderType` | Optional image builder configuration, if not provided, the default image builder will be used. |
| `images` | `typing.Dict[str, str] \| None` | Optional dict of images that can be used by referencing the image name. |
| `source_config_path` | `Optional[Path]` | Optional path to the source configuration file (This is only used for documentation) |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True). |
| `load_plugin_type_transformers` | `bool` | If enabled (default True), load the type transformer plugins registered under the "flyte.plugins.types" entry point group. |
#### init_from_api_key()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await init_from_api_key.aio()`.
```python
def init_from_api_key(
endpoint: str,
api_key: str | None,
project: str | None,
domain: str | None,
root_dir: Path | None,
log_level: int | None,
log_format: LogFormat | None,
storage: Storage | None,
batch_size: int,
image_builder: ImageBuildEngine.ImageBuilderType,
images: typing.Dict[str, str] | None,
sync_local_sys_paths: bool,
)
```
Initialize the Flyte system using an API key for authentication. This is a convenience
method for API key-based authentication. Thread-safe implementation.
| Parameter | Type | Description |
|-|-|-|
| `endpoint` | `str` | The Flyte API endpoint URL |
| `api_key` | `str \| None` | Optional API key for authentication. If None, reads from FLYTE_API_KEY environment variable. |
| `project` | `str \| None` | Optional project name |
| `domain` | `str \| None` | Optional domain name |
| `root_dir` | `Path \| None` | Optional root directory used to resolve paths to files. Defaults to the editable install directory if the cwd is inside a Python editable install, else just the cwd. |
| `log_level` | `int \| None` | Optional logging level for the logger |
| `log_format` | `LogFormat \| None` | Optional logging format for the logger, default is "console" |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration |
| `batch_size` | `int` | Optional batch size for operations that use listings, defaults to 1000 |
| `image_builder` | `ImageBuildEngine.ImageBuilderType` | Optional image builder configuration |
| `images` | `typing.Dict[str, str] \| None` | Optional dict of images that can be used by referencing the image name |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True). |
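A minimal sketch (the endpoint value is a placeholder; with `api_key=None`, the key is read from the `FLYTE_API_KEY` environment variable):
```python
import flyte

if __name__ == "__main__":
    flyte.init_from_api_key(
        endpoint="my-flyte.example.com",  # placeholder endpoint
        project="my-project",
        domain="development",
    )
```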
#### init_from_config()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await init_from_config.aio()`.
```python
def init_from_config(
path_or_config: str | Path | Config | None,
root_dir: Path | None,
log_level: int | None,
log_format: LogFormat,
project: str | None,
domain: str | None,
storage: Storage | None,
batch_size: int,
image_builder: ImageBuildEngine.ImageBuilderType | None,
images: tuple[str, ...] | None,
sync_local_sys_paths: bool,
)
```
Initialize the Flyte system using a configuration file or Config object. This method should be called before any
other Flyte remote API methods are called. Thread-safe implementation.
| Parameter | Type | Description |
|-|-|-|
| `path_or_config` | `str \| Path \| Config \| None` | Path to the configuration file or Config object |
| `root_dir` | `Path \| None` | Optional root directory used to resolve paths to files such as config. For example, when using `copy_style="all"`, it is essential to determine the root directory of the current project. If not provided, defaults to the editable install directory or, if unavailable, the current working directory. |
| `log_level` | `int \| None` | Optional logging level for the framework logger, default is set using the default initialization policies |
| `log_format` | `LogFormat` | Optional logging format for the logger, default is "console" |
| `project` | `str \| None` | Project name, this will override any project names in the configuration file |
| `domain` | `str \| None` | Domain name, this will override any domain names in the configuration file |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration, if needed for access (e.g., when using Minio). |
| `batch_size` | `int` | Optional batch size for operations that use listings, defaults to 1000 |
| `image_builder` | `ImageBuildEngine.ImageBuilderType \| None` | Optional image builder configuration; if provided, overrides any defaults set in the configuration. |
| `images` | `tuple[str, ...] \| None` | List of image strings in format "imagename=imageuri" or just "imageuri". |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True). |
#### init_in_cluster()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await init_in_cluster.aio()`.
```python
def init_in_cluster(
org: str | None,
project: str | None,
domain: str | None,
api_key: str | None,
endpoint: str | None,
insecure: bool,
) -> dict[str, typing.Any]
```
| Parameter | Type | Description |
|-|-|-|
| `org` | `str \| None` | |
| `project` | `str \| None` | |
| `domain` | `str \| None` | |
| `api_key` | `str \| None` | |
| `endpoint` | `str \| None` | |
| `insecure` | `bool` | |
#### map()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await flyte.map.aio()`.
```python
def map(
func: typing.Union[flyte._task.AsyncFunctionTaskTemplate[~P, ~R, ~F], functools.partial[~R]],
args: *args,
group_name: str | None,
concurrency: int,
return_exceptions: bool,
) -> typing.Iterator[typing.Union[~R, Exception]]
```
Map a function over the provided arguments with concurrent execution.
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Union[flyte._task.AsyncFunctionTaskTemplate[~P, ~R, ~F], functools.partial[~R]]` | The async function to map. |
| `args` | `*args` | Positional arguments to pass to the function (iterables that will be zipped). |
| `group_name` | `str \| None` | The name of the group for the mapped tasks. |
| `concurrency` | `int` | The maximum number of concurrent tasks to run. If 0, run all tasks concurrently. |
| `return_exceptions` | `bool` | If True, yield exceptions instead of raising them. |
:return: Iterator yielding results in order.
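A minimal fan-out sketch using `map` inside a task:
```python
import flyte

env = flyte.TaskEnvironment(name="map_example")

@env.task
def square(x: int) -> int:
    return x * x

@env.task
def main() -> int:
    # Fan out over the inputs with at most 4 tasks in flight; results arrive in input order.
    results = list(flyte.map(square, range(10), concurrency=4))
    return sum(r for r in results if not isinstance(r, Exception))
```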
#### run()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await run.aio()`.
```python
def run(
task: TaskTemplate[P, R, F],
args: *args,
kwargs: **kwargs,
) -> Run
```
Run a task with the given parameters.
| Parameter | Type | Description |
|-|-|-|
| `task` | `TaskTemplate[P, R, F]` | Task to run. |
| `args` | `*args` | Args to pass to the task. |
| `kwargs` | `**kwargs` | Kwargs to pass to the task. |
:return: Run | Result of the task.
#### serve()
> [!NOTE] This method can be called either synchronously or asynchronously.
> The default invocation is synchronous and will block.
> To call it asynchronously, invoke `.aio()` on the method itself, e.g.,
> `result = await serve.aio()`.
```python
def serve(
app_env: 'AppEnvironment',
) -> 'App'
```
Serve a Flyte app using an AppEnvironment.
This is the simple, direct way to serve an app. For more control over
deployment settings (env vars, cluster pool, etc.), use `with_servecontext()`.
Example:
```python
import flyte
from flyte.app.extras import FastAPIAppEnvironment
env = FastAPIAppEnvironment(name="my-app", ...)
# Simple serve
app = flyte.serve(env)
print(f"App URL: {app.url}")
```
| Parameter | Type | Description |
|-|-|-|
| `app_env` | `'AppEnvironment'` | The app environment to serve |
#### trace()
```python
def trace(
func: typing.Callable[..., ~T],
) -> typing.Callable[..., ~T]
```
A decorator that traces function execution with timing information.
Works with regular functions, async functions, and async generators/iterators.
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Callable[..., ~T]` | |
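A minimal sketch of tracing a helper function called from within a task:
```python
import flyte

env = flyte.TaskEnvironment(name="trace_example")

@flyte.trace
def expensive_helper(n: int) -> int:
    return sum(i * i for i in range(n))

@env.task
def main() -> int:
    # The traced call is recorded with timing information.
    return expensive_helper(1_000_000)
```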
#### version()
```python
def version()
```
Returns the version of the Flyte SDK.
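For example:
```python
import flyte

print(flyte.version())  # prints the installed SDK version string
```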
#### with_runcontext()
```python
def with_runcontext(
mode: Mode | None,
name: Optional[str],
service_account: Optional[str],
version: Optional[str],
copy_style: CopyFiles,
dry_run: bool,
copy_bundle_to: pathlib.Path | None,
interactive_mode: bool | None,
raw_data_path: str | None,
run_base_dir: str | None,
overwrite_cache: bool,
project: str | None,
domain: str | None,
env_vars: Dict[str, str] | None,
labels: Dict[str, str] | None,
annotations: Dict[str, str] | None,
interruptible: bool | None,
log_level: int | None,
log_format: LogFormat,
disable_run_cache: bool,
queue: Optional[str],
custom_context: Dict[str, str] | None,
cache_lookup_scope: CacheLookupScope,
) -> _Runner
```
Launch a new run with the given parameters as the context.
Example:
```python
import flyte
env = flyte.TaskEnvironment("example")
@env.task
async def example_task(x: int, y: str) -> str:
return f"{x} {y}"
if __name__ == "__main__":
flyte.with_runcontext(name="example_run_id").run(example_task, 1, y="hello")
```
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Mode \| None` | Optional. The mode to use for the run; if not provided, it is computed from `flyte.init`. |
| `name` | `Optional[str]` | Optional. The name to use for the run. |
| `service_account` | `Optional[str]` | Optional. The service account to use for the run context. |
| `version` | `Optional[str]` | Optional. The version to use for the run; if not provided, it is computed from the code bundle. |
| `copy_style` | `CopyFiles` | Optional. The copy style to use for the run context. |
| `dry_run` | `bool` | Optional. If true, the run is not executed, but the bundle is created. |
| `copy_bundle_to` | `pathlib.Path \| None` | When `dry_run` is True, the bundle is copied to this location if specified. |
| `interactive_mode` | `bool \| None` | Optional; can be forced to True or False. If not provided, it is set based on the current environment: for example, Jupyter notebooks are considered interactive mode, while scripts are not. This is used to determine how the code bundle is created. |
| `raw_data_path` | `str \| None` | Path used to store the raw data for the run, both locally and remotely; can be used to store raw data in specific locations. |
| `run_base_dir` | `str \| None` | Optional. The base directory to use for the run; used to store the run metadata that is passed between tasks. |
| `overwrite_cache` | `bool` | Optional. If true, the cache is overwritten for the run. |
| `project` | `str \| None` | Optional. The project to use for the run. |
| `domain` | `str \| None` | Optional. The domain to use for the run. |
| `env_vars` | `Dict[str, str] \| None` | Optional. Environment variables to set for the run. |
| `labels` | `Dict[str, str] \| None` | Optional. Labels to set for the run. |
| `annotations` | `Dict[str, str] \| None` | Optional. Annotations to set for the run. |
| `interruptible` | `bool \| None` | Optional. If true, the run can be scheduled on interruptible instances; false means all tasks in the run are scheduled only on non-interruptible instances. If not specified, the original setting on each task is retained. |
| `log_level` | `int \| None` | Optional. Log level for the run; if not provided, defaults to the log level set via `flyte.init()`. |
| `log_format` | `LogFormat` | Optional. Log format for the run; if not provided, defaults to the default log format. |
| `disable_run_cache` | `bool` | Optional. If true, the run cache is disabled; useful for testing purposes. |
| `queue` | `Optional[str]` | Optional. The queue to use for the run; used to specify the cluster for the run. |
| `custom_context` | `Dict[str, str] \| None` | Optional. Global input context to pass to the task; available via `get_custom_context()` within the task and automatically propagated to sub-tasks. Acts as base/default values that can be overridden by context managers in the code. |
| `cache_lookup_scope` | `CacheLookupScope` | Optional. Scope to use for cache lookups; if not specified, the default scope is used (global unless overridden at the system level). |
:return: runner
#### with_servecontext()
```python
def with_servecontext(
version: Optional[str],
copy_style: CopyFiles,
dry_run: bool,
project: str | None,
domain: str | None,
env_vars: dict[str, str] | None,
input_values: dict[str, dict[str, str | flyte.io.File | flyte.io.Dir]] | None,
cluster_pool: str | None,
log_level: int | None,
log_format: LogFormat,
) -> _Serve
```
Create a serve context with custom configuration.
This function allows you to customize how an app is served, including
overriding environment variables, cluster pool, logging, and other deployment settings.
Example:
```python
import logging
import flyte
from flyte.app.extras import FastAPIAppEnvironment
env = FastAPIAppEnvironment(name="my-app", ...)
# Serve with custom env vars, logging, and cluster pool
app = flyte.with_servecontext(
env_vars={"DATABASE_URL": "postgresql://..."},
log_level=logging.DEBUG,
log_format="json",
cluster_pool="gpu-pool",
project="my-project",
domain="development",
).serve(env)
print(f"App URL: {app.url}")
```
| Parameter | Type | Description |
|-|-|-|
| `version` | `Optional[str]` | Optional version override for the app deployment |
| `copy_style` | `CopyFiles` | |
| `dry_run` | `bool` | |
| `project` | `str \| None` | Optional project override |
| `domain` | `str \| None` | Optional domain override |
| `env_vars` | `dict[str, str] \| None` | Optional environment variables to inject/override in the app container |
| `input_values` | `dict[str, dict[str, str \| flyte.io.File \| flyte.io.Dir]] \| None` | Optional input values to inject/override in the app container. Must be a dictionary that maps app environment names to a dictionary of input names to values. |
| `cluster_pool` | `str \| None` | Optional cluster pool to deploy the app to |
| `log_level` | `int \| None` | Optional log level (e.g., logging.DEBUG, logging.INFO). If not provided, uses init config or default |
| `log_format` | `LogFormat` | |
## Subpages
- [Cache](Cache/)
- [CachePolicy](CachePolicy/)
- [Cron](Cron/)
- [Device](Device/)
- [Environment](Environment/)
- [FixedRate](FixedRate/)
- [Image](Image/)
- [PodTemplate](PodTemplate/)
- [Resources](Resources/)
- [RetryStrategy](RetryStrategy/)
- [ReusePolicy](ReusePolicy/)
- [Secret](Secret/)
- [TaskEnvironment](TaskEnvironment/)
- [Timeout](Timeout/)
- [Trigger](Trigger/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app ===
# flyte.app
## Directory
### Classes
| Class | Description |
|-|-|
| `AppEndpoint` | Embed an upstream app's endpoint as an app input. |
| `AppEnvironment` | |
| `Domain` | Subdomain to use for the domain. |
| `Input` | Input for application. |
| `Link` | Custom links to add to the app. |
| `Port` | |
| `RunOutput` | Use a run's output for app inputs. |
| `Scaling` | |
### Methods
| Method | Description |
|-|-|
| `get_input()` | Get inputs for an application or endpoint. |
## Methods
#### get_input()
```python
def get_input(
name: str,
) -> str
```
Get inputs for an application or endpoint.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
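A minimal sketch, assuming the app environment declares an input named `"model"` (the name is a placeholder):
```python
from flyte.app import get_input

# Inside app code: resolve the value of a declared input by name.
model_path = get_input("model")
```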
## Subpages
- [AppEndpoint](AppEndpoint/)
- [AppEnvironment](AppEnvironment/)
- [Domain](Domain/)
- [Input](Input/)
- [Link](Link/)
- [Port](Port/)
- [RunOutput](RunOutput/)
- [Scaling](Scaling/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/appendpoint ===
# AppEndpoint
**Package:** `flyte.app`
Embed an upstream app's endpoint as an app input.
This enables the declaration of an app input dependency on the endpoint of
an upstream app, given by a specific app name. This gives the app access to
the upstream app's endpoint as a public or private URL.
```python
class AppEndpoint(
data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | |
## Methods
| Method | Description |
|-|-|
| `check_type()` | |
| `construct()` | |
| `copy()` | Returns a copy of the model. |
| `dict()` | |
| `from_orm()` | |
| `get()` | |
| `json()` | |
| `materialize()` | |
| `model_construct()` | Creates a new instance of the `Model` class with validated data. |
| `model_copy()` | Returns a copy of the model. |
| `model_dump()` | Generate a dictionary representation of the model. |
| `model_dump_json()` | Generates a JSON representation of the model. |
| `model_json_schema()` | Generates a JSON schema for a model class. |
| `model_parametrized_name()` | Compute the class name for parametrizations of generic classes. |
| `model_post_init()` | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| `model_rebuild()` | Try to rebuild the pydantic-core schema for the model. |
| `model_validate()` | Validate a pydantic model instance. |
| `model_validate_json()` | Validate the given JSON data against the Pydantic model. |
| `model_validate_strings()` | Validate the given object with string data against the Pydantic model. |
| `parse_file()` | |
| `parse_obj()` | |
| `parse_raw()` | |
| `schema()` | |
| `schema_json()` | |
| `update_forward_refs()` | |
| `validate()` | |
### check_type()
```python
def check_type(
data: typing.Any,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `data` | `typing.Any` | |
### construct()
```python
def construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
update: Dict[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### get()
```python
def get()
```
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### materialize()
```python
def materialize()
```
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class with validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
| `update` | `Mapping[str, Any] \| None` | |
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template. |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, as a subclass of `GenerateJsonSchema` with your desired modifications |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | How to format unions: `'any_of'` uses the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default); `'primitive_type_array'` uses the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings, containing each type of the combination. If any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer` or `number`) or contains constraints/metadata, falls back to `any_of`. |
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Extra fields set during validation, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | The set of fields that have been explicitly set on this model instance, i.e. that were not filled from defaults. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app/runoutput ===
# RunOutput
**Package:** `flyte.app`
Use a run's output for app inputs.
This enables the declaration of an app input dependency on the output of
a run, given by a specific run name, or a task name and version. If
`task_auto_version == 'latest'`, the latest version of the task will be used.
If `task_auto_version == 'current'`, the version will be derived from the callee
app or task context. To get the latest run of an ephemeral task, leave both
`task_version` and `task_auto_version` set to `None` (which is the default).
Examples:
Get the output of a specific run:
```python
run_output = RunOutput(type="directory", run_name="my-run-123")
```
Get the latest output of an ephemeral task run:
```python
run_output = RunOutput(type="file", task_name="env.my_task")
```
Get the latest output of a deployed task run:
```python
run_output = RunOutput(type="file", task_name="env.my_task", task_auto_version="latest")
```
Get the output of a specific task run:
```python
run_output = RunOutput(type="file", task_name="env.my_task", task_version="xyz")
```
```python
class RunOutput(
data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | |
## Methods
| Method | Description |
|-|-|
| `check_type()` | |
| `construct()` | |
| `copy()` | Returns a copy of the model. |
| `dict()` | |
| `from_orm()` | |
| `get()` | |
| `json()` | |
| `materialize()` | |
| `model_construct()` | Creates a new instance of the `Model` class with validated data. |
| `model_copy()` | Returns a copy of the model. |
| `model_dump()` | Generate a dictionary representation of the model. |
| `model_dump_json()` | Generates a JSON representation of the model. |
| `model_json_schema()` | Generates a JSON schema for a model class. |
| `model_parametrized_name()` | Compute the class name for parametrizations of generic classes. |
| `model_post_init()` | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| `model_rebuild()` | Try to rebuild the pydantic-core schema for the model. |
| `model_validate()` | Validate a pydantic model instance. |
| `model_validate_json()` | Validate the given JSON data against the Pydantic model. |
| `model_validate_strings()` | Validate the given object with string data against the Pydantic model. |
| `parse_file()` | |
| `parse_obj()` | |
| `parse_raw()` | |
| `schema()` | |
| `schema_json()` | |
| `update_forward_refs()` | |
| `validate()` | |
### check_type()
```python
def check_type(
data: typing.Any,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `data` | `typing.Any` | |
### construct()
```python
def construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
update: Dict[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### get()
```python
def get()
```
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### materialize()
```python
def materialize()
```
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class with validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
| `update` | `Mapping[str, Any] \| None` | |
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
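Continuing the hypothetical `Point` model from the `model_construct()` example above:
```python
p = Point(x=1)
assert p.model_dump() == {"x": 1, "y": 0}
assert p.model_dump(exclude={"y"}) == {"x": 1}          # drop a field by name
assert p.model_dump(exclude_defaults=True) == {"x": 1}  # y kept its default
```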
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template. |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, as a subclass of `GenerateJsonSchema` with your desired modifications. |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | How to format unions. - `'any_of'`: Use the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default). - `'primitive_type_array'`: Use the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings, containing each type of the combination. If any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer` or `number`) or contains constraints/metadata, falls back to `any_of`. |
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
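Again with the hypothetical `Point` model:
```python
from pydantic import ValidationError

p = Point.model_validate_json('{"x": 3}')
assert p.x == 3

try:
    Point.model_validate_json('{"y": 1}')  # missing required field "x"
except ValidationError as exc:
    print(exc)
```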
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Extra fields set during validation, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | The set of fields that have been explicitly set on this model instance, i.e. that were not filled from defaults. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app.extras ===
# flyte.app.extras
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment** | |
## Subpages
- **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.app.extras/fastapiappenvironment ===
# FastAPIAppEnvironment
**Package:** `flyte.app.extras`
```python
class FastAPIAppEnvironment(
name: str,
depends_on: List[Environment],
pod_template: Optional[Union[str, PodTemplate]],
description: Optional[str],
secrets: Optional[SecretRequest],
env_vars: Optional[Dict[str, str]],
resources: Optional[Resources],
interruptible: bool,
image: Union[str, Image, Literal['auto']],
port: int | Port,
args: *args,
command: Optional[Union[List[str], str]],
requires_auth: bool,
scaling: Scaling,
domain: Domain | None,
links: List[Link],
include: List[str],
inputs: List[Input],
cluster_pool: str,
type: str,
app: fastapi.FastAPI,
_caller_frame: inspect.FrameInfo | None,
)
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `depends_on` | `List[Environment]` | |
| `pod_template` | `Optional[Union[str, PodTemplate]]` | |
| `description` | `Optional[str]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `env_vars` | `Optional[Dict[str, str]]` | |
| `resources` | `Optional[Resources]` | |
| `interruptible` | `bool` | |
| `image` | `Union[str, Image, Literal['auto']]` | |
| `port` | `int \| Port` | |
| `args` | `*args` | |
| `command` | `Optional[Union[List[str], str]]` | |
| `requires_auth` | `bool` | |
| `scaling` | `Scaling` | |
| `domain` | `Domain \| None` | |
| `links` | `List[Link]` | |
| `include` | `List[str]` | |
| `inputs` | `List[Input]` | |
| `cluster_pool` | `str` | |
| `type` | `str` | |
| `app` | `fastapi.FastAPI` | |
| `_caller_frame` | `inspect.FrameInfo \| None` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > `add_dependency()`** | Add a dependency to the environment. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > `clone_with()`** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > `container_args()`** | Generate the container arguments for running the FastAPI app with uvicorn. |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > `container_cmd()`** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > `container_command()`** | |
| **Flyte SDK > Packages > flyte.app.extras > FastAPIAppEnvironment > `get_port()`** | |
### add_dependency()
```python
def add_dependency(
env: Environment,
)
```
Add a dependency to the environment.
| Parameter | Type | Description |
|-|-|-|
| `env` | `Environment` | |
### clone_with()
```python
def clone_with(
name: str,
image: Optional[Union[str, Image, Literal['auto']]],
resources: Optional[Resources],
env_vars: Optional[dict[str, str]],
secrets: Optional[SecretRequest],
depends_on: Optional[List[Environment]],
description: Optional[str],
interruptible: Optional[bool],
kwargs: **kwargs,
) -> AppEnvironment
```
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | |
| `image` | `Optional[Union[str, Image, Literal['auto']]]` | |
| `resources` | `Optional[Resources]` | |
| `env_vars` | `Optional[dict[str, str]]` | |
| `secrets` | `Optional[SecretRequest]` | |
| `depends_on` | `Optional[List[Environment]]` | |
| `description` | `Optional[str]` | |
| `interruptible` | `Optional[bool]` | |
| `kwargs` | `**kwargs` | |
### container_args()
```python
def container_args(
serialization_context: flyte.models.SerializationContext,
) -> list[str]
```
Generate the container arguments for running the FastAPI app with uvicorn.
Returns a list of command arguments in the format:
`["uvicorn", "<module_name>:<app_var_name>", "--port", "<port>"]`
| Parameter | Type | Description |
|-|-|-|
| `serialization_context` | `flyte.models.SerializationContext` | |
### container_cmd()
```python
def container_cmd(
serialize_context: SerializationContext,
input_overrides: list[Input] | None,
) -> List[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialize_context` | `SerializationContext` | |
| `input_overrides` | `list[Input] \| None` | |
### container_command()
```python
def container_command(
serialization_context: flyte.models.SerializationContext,
) -> list[str]
```
| Parameter | Type | Description |
|-|-|-|
| `serialization_context` | `flyte.models.SerializationContext` | |
### get_port()
```python
def get_port()
```
## Properties
| Property | Type | Description |
|-|-|-|
| `endpoint` | `None` | |
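Putting it together, a minimal sketch of serving a FastAPI app (only `name` and `app` are passed; the remaining constructor parameters are assumed to have defaults):
```python
import fastapi
from flyte.app.extras import FastAPIAppEnvironment

app = fastapi.FastAPI()

@app.get("/health")
def health() -> dict:
    return {"status": "ok"}

# A sketch: the environment wraps the FastAPI app and will run it
# with uvicorn via container_args() at serving time.
env = FastAPIAppEnvironment(
    name="my-fastapi-app",
    app=app,
)
```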
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.config ===
# flyte.config
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.config > `Config`** | This is the parent configuration object; it holds all the underlying configuration object types. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.config > Methods > auto()** | Automatically constructs the Config Object. |
| **Flyte SDK > Packages > flyte.config > `set_if_exists()`** | Given a dict ``d``, sets the key ``k`` to the config value ``val`` if ``val`` is set. |
## Methods
#### auto()
```python
def auto(
config_file: typing.Union[str, pathlib.Path, ConfigFile, None],
) -> Config
```
Automatically constructs the Config Object. The order of precedence is as follows:
1. If specified, read the config from the provided file path.
2. If not specified, the config file is searched in the default locations.
a. ./config.yaml if it exists (current working directory)
b. ./.flyte/config.yaml if it exists (current working directory)
c. <git_root>/.flyte/config.yaml if it exists
d. `UCTL_CONFIG` environment variable
e. `FLYTECTL_CONFIG` environment variable
f. ~/.union/config.yaml if it exists
g. ~/.flyte/config.yaml if it exists
3. If any value is not found in the config file, the default value is used.
4. For any value, environment variables that match the config variable names override the value from the config file.
| Parameter | Type | Description |
|-|-|-|
| `config_file` | `typing.Union[str, pathlib.Path, ConfigFile, None]` | File path to read the config from. If not specified, the default locations are searched. Returns a Config object. |
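A minimal sketch (the explicit config path below is hypothetical, and `config_file` is assumed to default to `None`):
```python
from flyte.config import auto

cfg = auto()                    # search the default locations listed above
cfg = auto("./my-config.yaml")  # an explicit path takes precedence
```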
#### set_if_exists()
```python
def set_if_exists(
d: dict,
k: str,
val: typing.Any,
) -> dict
```
Given a dict ``d``, sets the key ``k`` to the config value ``val`` if ``val`` is set,
and returns the updated dictionary.
| Parameter | Type | Description |
|-|-|-|
| `d` | `dict` | |
| `k` | `str` | |
| `val` | `typing.Any` | |
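For example (a sketch; "set" is assumed here to mean a non-empty value):
```python
from flyte.config import set_if_exists

d: dict = {}
set_if_exists(d, "endpoint", "dns:///flyte.example.com")  # value is set -> key added
set_if_exists(d, "insecure", None)                        # value unset -> d unchanged
print(d)  # {'endpoint': 'dns:///flyte.example.com'}
```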
## Subpages
- [Config](Config/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.errors ===
# flyte.errors
Exceptions raised by Union.
These errors are raised when the underlying task execution fails, either because of a user error, system error or an
unknown error.
## Directory
### Errors
| Exception | Description |
|-|-|
| **Flyte SDK > Packages > flyte.errors > `ActionNotFoundError`** | This error is raised when the user tries to access an action that does not exist. |
| **Flyte SDK > Packages > flyte.errors > `BaseRuntimeError`** | Base class for all Union runtime errors. |
| **Flyte SDK > Packages > flyte.errors > `CustomError`** | This error is raised when the user raises a custom error. |
| **Flyte SDK > Packages > flyte.errors > `DeploymentError`** | This error is raised when the deployment of a task fails, or some preconditions for deployment are not met. |
| **Flyte SDK > Packages > flyte.errors > `ImageBuildError`** | This error is raised when the image build fails. |
| **Flyte SDK > Packages > flyte.errors > `ImagePullBackOffError`** | This error is raised when the image cannot be pulled. |
| **Flyte SDK > Packages > flyte.errors > `InitializationError`** | This error is raised when the Union system is accessed without being initialized. |
| **Flyte SDK > Packages > flyte.errors > `InlineIOMaxBytesBreached`** | This error is raised when the inline IO max bytes limit is breached. |
| **Flyte SDK > Packages > flyte.errors > `InvalidImageNameError`** | This error is raised when the image name is invalid. |
| **Flyte SDK > Packages > flyte.errors > `LogsNotYetAvailableError`** | This error is raised when the logs are not yet available for a task. |
| **Flyte SDK > Packages > flyte.errors > `ModuleLoadError`** | This error is raised when the module cannot be loaded, either because it does not exist or because of an error during loading. |
| **Flyte SDK > Packages > flyte.errors > `NotInTaskContextError`** | This error is raised when the user tries to access the task context outside of a task. |
| **Flyte SDK > Packages > flyte.errors > `OOMError`** | This error is raised when the underlying task execution fails because of an out-of-memory error. |
| **Flyte SDK > Packages > flyte.errors > `OnlyAsyncIOSupportedError`** | This error is raised when the user tries to use sync IO in an async task. |
| **Flyte SDK > Packages > flyte.errors > `PrimaryContainerNotFoundError`** | This error is raised when the primary container is not found. |
| **Flyte SDK > Packages > flyte.errors > `ReferenceTaskError`** | This error is raised when the user tries to access a task that does not exist. |
| **Flyte SDK > Packages > flyte.errors > `RetriesExhaustedError`** | This error is raised when the underlying task execution fails after all retries have been exhausted. |
| **Flyte SDK > Packages > flyte.errors > `RunAbortedError`** | This error is raised when the run is aborted by the user. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeDataValidationError`** | This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeSystemError`** | This error is raised when the underlying task execution fails because of a system error. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeUnknownError`** | This error is raised when the underlying task execution fails because of an unknown error. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeUserError`** | This error is raised when the underlying task execution fails because of an error in the user's code. |
| **Flyte SDK > Packages > flyte.errors > `SlowDownError`** | This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > `TaskInterruptedError`** | This error is raised when the underlying task execution is interrupted. |
| **Flyte SDK > Packages > flyte.errors > `TaskTimeoutError`** | This error is raised when the underlying task execution runs for longer than the specified timeout. |
| **Flyte SDK > Packages > flyte.errors > `UnionRpcError`** | This error is raised when communication with the Union server fails. |
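Several of these exceptions surface when waiting on a run. A minimal sketch (assuming an initialized client and that failures propagate as exceptions from `wait()`):
```python
import flyte
import flyte.errors

env = flyte.TaskEnvironment("errors_example")

@env.task
def main(name: str) -> str:
    return f"Hello, {name}!"

if __name__ == "__main__":
    flyte.init_from_config()
    try:
        run = flyte.run(main, name="World")
        run.wait()
    except flyte.errors.OOMError:
        print("Task ran out of memory; retry with larger resources.")
    except flyte.errors.RuntimeUserError as exc:
        print(f"Task failed in user code: {exc}")
```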
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.errors > `silence_grpc_polling_error()`** | Suppress specific gRPC polling errors in the event loop. |
## Methods
#### silence_grpc_polling_error()
```python
def silence_grpc_polling_error(
loop,
context,
)
```
Suppress specific gRPC polling errors in the event loop.
| Parameter | Type | Description |
|-|-|-|
| `loop` | | |
| `context` | | |
## Subpages
- [ActionNotFoundError](ActionNotFoundError/)
- [BaseRuntimeError](BaseRuntimeError/)
- [CustomError](CustomError/)
- [DeploymentError](DeploymentError/)
- [ImageBuildError](ImageBuildError/)
- [ImagePullBackOffError](ImagePullBackOffError/)
- [InitializationError](InitializationError/)
- [InlineIOMaxBytesBreached](InlineIOMaxBytesBreached/)
- [InvalidImageNameError](InvalidImageNameError/)
- [LogsNotYetAvailableError](LogsNotYetAvailableError/)
- [ModuleLoadError](ModuleLoadError/)
- [NotInTaskContextError](NotInTaskContextError/)
- [OnlyAsyncIOSupportedError](OnlyAsyncIOSupportedError/)
- [OOMError](OOMError/)
- [PrimaryContainerNotFoundError](PrimaryContainerNotFoundError/)
- [ReferenceTaskError](ReferenceTaskError/)
- [RetriesExhaustedError](RetriesExhaustedError/)
- [RunAbortedError](RunAbortedError/)
- [RuntimeDataValidationError](RuntimeDataValidationError/)
- [RuntimeSystemError](RuntimeSystemError/)
- [RuntimeUnknownError](RuntimeUnknownError/)
- [RuntimeUserError](RuntimeUserError/)
- [SlowDownError](SlowDownError/)
- [TaskInterruptedError](TaskInterruptedError/)
- [TaskTimeoutError](TaskTimeoutError/)
- [UnionRpcError](UnionRpcError/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extend ===
# flyte.extend
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > `AsyncFunctionTaskTemplate`** | A task template that wraps an asynchronous function. |
| **Flyte SDK > Packages > flyte.extend > `ImageBuildEngine`** | ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. |
| **Flyte SDK > Packages > flyte.extend > `TaskTemplate`** | Task template is a template for a task that can be executed. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > `download_code_bundle()`** | Downloads the code bundle if it is not already downloaded. |
| **Flyte SDK > Packages > flyte.extend > `get_proto_resources()`** | Get main resources IDL representation from the resources object. |
| **Flyte SDK > Packages > flyte.extend > `is_initialized()`** | Check if the system has been initialized. |
| **Flyte SDK > Packages > flyte.extend > `pod_spec_from_resources()`** | |
### Variables
| Property | Type | Description |
|-|-|-|
| `PRIMARY_CONTAINER_DEFAULT_NAME` | `str` | |
| `TaskPluginRegistry` | `_Registry` | |
## Methods
#### download_code_bundle()
```python
def download_code_bundle(
code_bundle: flyte.models.CodeBundle,
) -> flyte.models.CodeBundle
```
Downloads the code bundle if it is not already downloaded.
| Parameter | Type | Description |
|-|-|-|
| `code_bundle` | `flyte.models.CodeBundle` | The code bundle to download. Returns the code bundle with the downloaded path. |
#### get_proto_resources()
```python
def get_proto_resources(
resources: flyte._resources.Resources | None,
) -> typing.Optional[flyteidl2.core.tasks_pb2.Resources]
```
Get main resources IDL representation from the resources object
| Parameter | Type | Description |
|-|-|-|
| `resources` | `flyte._resources.Resources \| None` | User-facing Resources object, potentially containing both requests and limits. Returns the given resources as requests and limits. |
#### is_initialized()
```python
def is_initialized()
```
Check if the system has been initialized.
Returns `True` if initialized, `False` otherwise.
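For example, as a guard before touching the SDK (a sketch):
```python
import flyte
from flyte.extend import is_initialized

# Initialize from the default config locations only if needed.
if not is_initialized():
    flyte.init_from_config()
```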
#### pod_spec_from_resources()
```python
def pod_spec_from_resources(
primary_container_name: str,
requests: typing.Optional[flyte._resources.Resources],
limits: typing.Optional[flyte._resources.Resources],
k8s_gpu_resource_key: str,
) -> V1PodSpec
```
| Parameter | Type | Description |
|-|-|-|
| `primary_container_name` | `str` | |
| `requests` | `typing.Optional[flyte._resources.Resources]` | |
| `limits` | `typing.Optional[flyte._resources.Resources]` | |
| `k8s_gpu_resource_key` | `str` | |
## Subpages
- [AsyncFunctionTaskTemplate](AsyncFunctionTaskTemplate/)
- [ImageBuildEngine](ImageBuildEngine/)
- [TaskTemplate](TaskTemplate/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.extras ===
# flyte.extras
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > `ContainerTask`** | This is an intermediate class that represents Flyte Tasks that run a container at execution time. |
## Subpages
- [ContainerTask](ContainerTask/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.git ===
# flyte.git
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > GitStatus** | A class representing the status of a git repository. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > `config_from_root()`** | Get the config file from the git root directory. |
## Methods
#### config_from_root()
```python
def config_from_root(
path: pathlib._local.Path | str,
) -> flyte.config._config.Config | None
```
Get the config file from the git root directory.
By default, the config file is expected to be in `.flyte/config.yaml` in the git root directory.
| Parameter | Type | Description |
|-|-|-|
| `path` | `pathlib._local.Path \| str` | Path to the config file relative to the git root directory (default: `.flyte/config.yaml`). Returns a Config object if found, `None` otherwise. |
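A minimal sketch (assuming the function can be called without arguments, using the default path):
```python
from flyte.git import config_from_root

# Looks for .flyte/config.yaml under the git root by default;
# returns None when no config file is found.
cfg = config_from_root()
```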
## Subpages
- **Flyte SDK > Packages > flyte.git > GitStatus**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.git/gitstatus ===
# GitStatus
**Package:** `flyte.git`
A class representing the status of a git repository.
```python
class GitStatus(
is_valid: bool,
is_tree_clean: bool,
remote_url: str,
repo_dir: pathlib._local.Path,
commit_sha: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `is_valid` | `bool` | Whether git repository is valid |
| `is_tree_clean` | `bool` | Whether working tree is clean |
| `remote_url` | `str` | Remote URL in HTTPS format |
| `repo_dir` | `pathlib._local.Path` | Repository root directory |
| `commit_sha` | `str` | Current commit SHA |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > GitStatus > `build_url()`** | Build a git URL for the given path. |
| **Flyte SDK > Packages > flyte.git > GitStatus > `from_current_repo()`** | Discover git information from the current repository. |
### build_url()
```python
def build_url(
path: pathlib._local.Path | str,
line_number: int,
) -> str
```
Build a git URL for the given path.
| Parameter | Type | Description |
|-|-|-|
| `path` | `pathlib._local.Path \| str` | Path to a file |
| `line_number` | `int` | Line number in the code file. |
### from_current_repo()
```python
def from_current_repo()
```
Discover git information from the current repository.
If Git is not installed or .git does not exist, returns GitStatus with is_valid=False.
Returns a GitStatus instance with the discovered git information.
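A minimal sketch (the file path passed to `build_url()` is hypothetical):
```python
from flyte.git import GitStatus

status = GitStatus.from_current_repo()
if status.is_valid:
    print(status.commit_sha)
    # Build a browsable URL for a file at a given line.
    print(status.build_url("src/main.py", line_number=10))
```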
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.io ===
# flyte.io
## IO data types
This package contains additional data types, beyond Python's primitive data types, that abstract the flow
of large datasets in Union.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io > `DataFrame`** | This is the user facing DataFrame class. |
| **Flyte SDK > Packages > flyte.io > `DataFrameDecoder`** | Helper class that provides a standard way to create an ABC using inheritance. |
| **Flyte SDK > Packages > flyte.io > `DataFrameEncoder`** | Helper class that provides a standard way to create an ABC using inheritance. |
| **Flyte SDK > Packages > flyte.io > `DataFrameTransformerEngine`** | Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. |
| **Flyte SDK > Packages > flyte.io > `Dir`** | A generic directory class representing a directory with files of a specified format. |
| **Flyte SDK > Packages > flyte.io > `File`** | A generic file class representing a file with a specified format. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io > `lazy_import_dataframe_handler()`** | |
### Variables
| Property | Type | Description |
|-|-|-|
| `PARQUET` | `str` | |
## Methods
#### lazy_import_dataframe_handler()
```python
def lazy_import_dataframe_handler()
```
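A sketch of how these types appear in a task signature (the `path` field access below is an assumption; see the `File` subpage for the exact API):
```python
import flyte
from flyte.io import File

env = flyte.TaskEnvironment("io_example")

@env.task
def consume(f: File) -> str:
    # Flyte passes a reference to the underlying blob-store object
    # rather than inlining the file contents.
    return f.path  # `path` is an assumed field; see the File subpage
```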
## Subpages
- [DataFrame](DataFrame/)
- [DataFrameDecoder](DataFrameDecoder/)
- [DataFrameEncoder](DataFrameEncoder/)
- [DataFrameTransformerEngine](DataFrameTransformerEngine/)
- [Dir](Dir/)
- [File](File/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models ===
# flyte.models
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > `ActionID`** | A class representing the ID of an Action, nested within a Run. |
| **Flyte SDK > Packages > flyte.models > ActionPhase** | Represents the execution phase of a Flyte action (run). |
| **Flyte SDK > Packages > flyte.models > `Checkpoints`** | A class representing the checkpoints for a task. |
| **Flyte SDK > Packages > flyte.models > `CodeBundle`** | A class representing a code bundle for a task. |
| **Flyte SDK > Packages > flyte.models > `GroupData`** | |
| **Flyte SDK > Packages > flyte.models > `NativeInterface`** | A class representing the native interface for a task. |
| **Flyte SDK > Packages > flyte.models > `PathRewrite`** | Configuration for rewriting paths during input loading. |
| **Flyte SDK > Packages > flyte.models > `RawDataPath`** | A class representing the raw data path for a task. |
| **Flyte SDK > Packages > flyte.models > `SerializationContext`** | This object holds serialization-time contextual information that can be used when serializing the task. |
| **Flyte SDK > Packages > flyte.models > `TaskContext`** | A context class to hold the current task executions context. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > `generate_random_name()`** | Generate a random name for the task. |
### Variables
| Property | Type | Description |
|-|-|-|
| `MAX_INLINE_IO_BYTES` | `int` | |
| `TYPE_CHECKING` | `bool` | |
## Methods
#### generate_random_name()
```python
def generate_random_name()
```
Generate a random name for the task. This is used to create unique names for tasks.
TODO: we can use unique-namer in the future; for now it's just GUIDs.
## Subpages
- [ActionID](ActionID/)
- **Flyte SDK > Packages > flyte.models > ActionPhase**
- [Checkpoints](Checkpoints/)
- [CodeBundle](CodeBundle/)
- [GroupData](GroupData/)
- [NativeInterface](NativeInterface/)
- [PathRewrite](PathRewrite/)
- [RawDataPath](RawDataPath/)
- [SerializationContext](SerializationContext/)
- [TaskContext](TaskContext/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.models/actionphase ===
# ActionPhase
**Package:** `flyte.models`
Represents the execution phase of a Flyte action (run).
Actions progress through different phases during their lifecycle:
- Queued: Action is waiting to be scheduled
- Waiting for resources: Action is waiting for compute resources
- Initializing: Action is being initialized
- Running: Action is currently executing
- Succeeded: Action completed successfully
- Failed: Action failed during execution
- Aborted: Action was manually aborted
- Timed out: Action exceeded its timeout limit
This enum can be used for filtering runs and checking execution status.
Example:
```python
from flyte.models import ActionPhase
from flyte.remote import Run

# Filter runs by phase
runs = Run.listall(in_phase=(ActionPhase.SUCCEEDED, ActionPhase.FAILED))

# Check if a run succeeded
run = Run.get("my-run")
if run.phase == ActionPhase.SUCCEEDED:
    print("Success!")

# Check if phase is terminal
if run.phase.is_terminal:
    print("Run completed")
```
```python
class ActionPhase(
args,
kwds,
)
```
| Parameter | Type | Description |
|-|-|-|
| `args` | `*args` | |
| `kwds` | | |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch ===
# flyte.prefetch
Prefetch utilities for Flyte.
This module provides functionality to prefetch various artifacts from remote registries,
such as HuggingFace models.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo** | Information about a HuggingFace model to store. |
| **Flyte SDK > Packages > flyte.prefetch > ShardConfig** | Configuration for model sharding. |
| **Flyte SDK > Packages > flyte.prefetch > StoredModelInfo** | Information about a stored model. |
| **Flyte SDK > Packages > flyte.prefetch > VLLMShardArgs** | Arguments for sharding a model using vLLM. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.prefetch > `hf_model()`** | Store a HuggingFace model to remote storage. |
## Methods
#### hf_model()
```python
def hf_model(
repo: str,
raw_data_path: str | None,
artifact_name: str | None,
architecture: str | None,
task: str,
modality: tuple[str, ...],
serial_format: str | None,
model_type: str | None,
short_description: str | None,
shard_config: ShardConfig | None,
hf_token_key: str,
resources: Resources,
force: int,
) -> Run
```
Store a HuggingFace model to remote storage.
This function downloads a model from the HuggingFace Hub and prefetches it to
remote storage. It supports optional sharding using vLLM for large models.
The prefetch behavior follows this priority:
1. If the model isn't being sharded, stream files directly to remote storage.
2. If streaming fails, fall back to downloading a snapshot and uploading.
3. If sharding is configured, download locally, shard with vLLM, then upload.
Returns a `Run` object representing the prefetch task execution.
Example usage:
```python
import flyte

flyte.init(endpoint="my-flyte-endpoint")

# Store a model without sharding
run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-7b-hf",
    hf_token_key="HF_TOKEN",
)
run.wait()

# Prefetch and shard a model
from flyte.prefetch import ShardConfig, VLLMShardArgs

run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-70b-hf",
    shard_config=ShardConfig(
        engine="vllm",
        args=VLLMShardArgs(tensor_parallel_size=8),
    ),
    resources=flyte.Resources(gpu="A100:8"),
    hf_token_key="HF_TOKEN",
)
run.wait()
```
| Parameter | Type | Description |
|-|-|-|
| `repo` | `str` | The HuggingFace repository ID (e.g., 'meta-llama/Llama-2-7b-hf'). |
| `raw_data_path` | `str \| None` | |
| `artifact_name` | `str \| None` | Optional name for the stored artifact. If not provided, the repo name will be used (with '.' replaced by '-'). |
| `architecture` | `str \| None` | Model architecture from HuggingFace config.json. |
| `task` | `str` | Model task (e.g., 'generate', 'classify', 'embed'). |
| `modality` | `tuple[str, ...]` | Modalities supported by the model. |
| `serial_format` | `str \| None` | Model serialization format (e.g., 'safetensors', 'onnx'). |
| `model_type` | `str \| None` | Model type (e.g., 'transformer', 'custom'). |
| `short_description` | `str \| None` | Short description of the model. |
| `shard_config` | `ShardConfig \| None` | Optional configuration for model sharding with vLLM. |
| `hf_token_key` | `str` | Name of the secret containing the HuggingFace token. |
| `resources` | `Resources` | |
| `force` | `int` | Force re-prefetch. Increment the value to force a new prefetch. |
## Subpages
- **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo**
- **Flyte SDK > Packages > flyte.prefetch > ShardConfig**
- **Flyte SDK > Packages > flyte.prefetch > StoredModelInfo**
- **Flyte SDK > Packages > flyte.prefetch > VLLMShardArgs**
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/huggingfacemodelinfo ===
# HuggingFaceModelInfo
**Package:** `flyte.prefetch`
Information about a HuggingFace model to store.
```python
class HuggingFaceModelInfo(
data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > Methods > construct()** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > Methods > copy()** | Returns a copy of the model. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > Methods > dict()** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `from_orm()`** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > Methods > json()** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_construct()`** | Creates a new instance of the `Model` class with validated data. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_copy()`** | Returns a copy of the model. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_dump()`** | Generate a dictionary representation of the model. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_dump_json()`** | Generates a JSON representation of the model. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_json_schema()`** | Generates a JSON schema for a model class. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_parametrized_name()`** | Compute the class name for parametrizations of generic classes. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_post_init()`** | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_rebuild()`** | Try to rebuild the pydantic-core schema for the model. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_validate()`** | Validate a pydantic model instance. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_validate_json()`** | Validate the given JSON data against the Pydantic model. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `model_validate_strings()`** | Validate the given object with string data against the Pydantic model. |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `parse_file()`** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `parse_obj()`** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `parse_raw()`** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > Methods > schema()** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `schema_json()`** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > `update_forward_refs()`** | |
| **Flyte SDK > Packages > flyte.prefetch > HuggingFaceModelInfo > Methods > validate()** | |
### construct()
```python
def construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
update: Dict[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class with validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
| `update` | `Mapping[str, Any] \| None` | |
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template. |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, as a subclass of `GenerateJsonSchema` with your desired modifications. |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | How to format unions. - `'any_of'`: Use the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default). - `'primitive_type_array'`: Use the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings, containing each type of the combination. If any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer` or `number`) or contains constraints/metadata, falls back to `any_of`. |
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
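A minimal sketch (again with a hypothetical `Point` model); parsing and validating in one step is generally faster than `json.loads` followed by `model_validate`:
```python
from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

# Accepts str, bytes, or bytearray containing JSON.
p = Point.model_validate_json(b'{"x": 1, "y": 2}')
```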
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
> [!WARNING] Deprecated
> This method is now deprecated; load the file yourself, then use `model_validate_json` (for JSON) or `model_validate` instead.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
> [!WARNING] Deprecated
> This method is now deprecated; use `model_validate` instead.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
> [!WARNING] Deprecated
> This method is now deprecated; use `model_validate_json` (for JSON data) or `model_validate` instead.
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
> [!WARNING] Deprecated
> This method is now deprecated; use `model_json_schema` instead.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
> [!WARNING] Deprecated
> This method is now deprecated; use `model_json_schema` with `json.dumps` instead.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
> [!WARNING] Deprecated
> This method is now deprecated; use `model_rebuild` instead.
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
> [!WARNING] Deprecated
> This method is now deprecated; use `model_validate` instead.
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Get extra fields set during validation. Returns a dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | The set of fields that have been explicitly set on this model instance, i.e. that were not filled from defaults. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/shardconfig ===
# ShardConfig
**Package:** `flyte.prefetch`
Configuration for model sharding.
```python
class ShardConfig(
data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | Field values, passed as keyword arguments. |
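A hedged construction sketch; ShardConfig's concrete fields are not listed here, so the keyword arguments below are left as placeholders:
```python
from pydantic import ValidationError

from flyte.prefetch import ShardConfig

try:
    cfg = ShardConfig()  # pass ShardConfig's fields as keyword arguments
except ValidationError as err:
    # Raised when the keyword arguments cannot be validated into a model.
    print(err)
```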
## Methods
| Method | Description |
|-|-|
| `construct()` | Deprecated; use `model_construct()` instead. |
| `copy()` | Returns a copy of the model. Deprecated; use `model_copy()` instead. |
| `dict()` | Deprecated; use `model_dump()` instead. |
| `from_orm()` | Deprecated; use `model_validate()` with `from_attributes=True` instead. |
| `json()` | Deprecated; use `model_dump_json()` instead. |
| `model_construct()` | Creates a new instance of the `Model` class with validated data. |
| `model_copy()` | Returns a copy of the model. |
| `model_dump()` | Generate a dictionary representation of the model. |
| `model_dump_json()` | Generates a JSON representation of the model. |
| `model_json_schema()` | Generates a JSON schema for a model class. |
| `model_parametrized_name()` | Compute the class name for parametrizations of generic classes. |
| `model_post_init()` | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| `model_rebuild()` | Try to rebuild the pydantic-core schema for the model. |
| `model_validate()` | Validate a pydantic model instance. |
| `model_validate_json()` | Validate the given JSON data against the Pydantic model. |
| `model_validate_strings()` | Validate the given object with string data against the Pydantic model. |
| `parse_file()` | Deprecated; load the file yourself, then use `model_validate_json()` or `model_validate()` instead. |
| `parse_obj()` | Deprecated; use `model_validate()` instead. |
| `parse_raw()` | Deprecated; use `model_validate_json()` instead. |
| `schema()` | Deprecated; use `model_json_schema()` instead. |
| `schema_json()` | Deprecated; use `model_json_schema()` with `json.dumps()` instead. |
| `update_forward_refs()` | Deprecated; use `model_rebuild()` instead. |
| `validate()` | Deprecated; use `model_validate()` instead. |
### construct()
```python
def construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
update: Dict[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class with validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
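A sketch of the trade-off on a throwaway model (`Point` is hypothetical): `model_construct` skips validation entirely, so it is fast but unsafe for untrusted input.
```python
from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int = 0

# No validation: the value is trusted as-is; defaults still apply
# for fields that are absent from the input.
p = Point.model_construct(x="not checked")
print(p.x, p.y)  # not checked 0
```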
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
| `update` | `Mapping[str, Any] \| None` | |
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
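A minimal sketch on a hypothetical `Point` model; note that values passed via `update` are not validated:
```python
from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

p = Point(x=1, y=2)
q = p.model_copy(update={"y": 9})  # shallow copy, one field overridden
r = p.model_copy(deep=True)        # nested model values are deep-copied
```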
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
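The `mode` parameter is the most commonly tuned knob; a sketch with a hypothetical `Event` model:
```python
from datetime import date

from pydantic import BaseModel

class Event(BaseModel):
    name: str
    when: date

e = Event(name="launch", when=date(2025, 1, 1))
e.model_dump()                  # {'name': 'launch', 'when': datetime.date(2025, 1, 1)}
e.model_dump(mode="json")       # {'name': 'launch', 'when': '2025-01-01'}
e.model_dump(exclude={"when"})  # {'name': 'launch'}
```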
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template. |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, pass a subclass of `GenerateJsonSchema` with your desired modifications. |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | How union types are represented: `'any_of'` uses the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default); `'primitive_type_array'` uses the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings containing each type of the combination, falling back to `any_of` if any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer`, or `number`) or contains constraints/metadata. |
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Get extra fields set during validation. Returns a dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | The set of fields that have been explicitly set on this model instance, i.e. that were not filled from defaults. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/storedmodelinfo ===
# StoredModelInfo
**Package:** `flyte.prefetch`
Information about a stored model.
```python
class StoredModelInfo(
data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | Field values, passed as keyword arguments. |
## Methods
| Method | Description |
|-|-|
| `construct()` | Deprecated; use `model_construct()` instead. |
| `copy()` | Returns a copy of the model. Deprecated; use `model_copy()` instead. |
| `dict()` | Deprecated; use `model_dump()` instead. |
| `from_orm()` | Deprecated; use `model_validate()` with `from_attributes=True` instead. |
| `json()` | Deprecated; use `model_dump_json()` instead. |
| `model_construct()` | Creates a new instance of the `Model` class with validated data. |
| `model_copy()` | Returns a copy of the model. |
| `model_dump()` | Generate a dictionary representation of the model. |
| `model_dump_json()` | Generates a JSON representation of the model. |
| `model_json_schema()` | Generates a JSON schema for a model class. |
| `model_parametrized_name()` | Compute the class name for parametrizations of generic classes. |
| `model_post_init()` | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| `model_rebuild()` | Try to rebuild the pydantic-core schema for the model. |
| `model_validate()` | Validate a pydantic model instance. |
| `model_validate_json()` | Validate the given JSON data against the Pydantic model. |
| `model_validate_strings()` | Validate the given object with string data against the Pydantic model. |
| `parse_file()` | Deprecated; load the file yourself, then use `model_validate_json()` or `model_validate()` instead. |
| `parse_obj()` | Deprecated; use `model_validate()` instead. |
| `parse_raw()` | Deprecated; use `model_validate_json()` instead. |
| `schema()` | Deprecated; use `model_json_schema()` instead. |
| `schema_json()` | Deprecated; use `model_json_schema()` with `json.dumps()` instead. |
| `update_forward_refs()` | Deprecated; use `model_rebuild()` instead. |
| `validate()` | Deprecated; use `model_validate()` instead. |
### construct()
```python
def construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
update: Dict[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class with validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
| `update` | `Mapping[str, Any] \| None` | |
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
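A minimal sketch on a hypothetical `Point` model showing the `indent` parameter:
```python
from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

p = Point(x=1, y=2)
print(p.model_dump_json())          # compact: {"x":1,"y":2}
print(p.model_dump_json(indent=2))  # pretty-printed over several lines
```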
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template. |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, pass a subclass of `GenerateJsonSchema` with your desired modifications. |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | How union types are represented: `'any_of'` uses the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default); `'primitive_type_array'` uses the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings containing each type of the combination, falling back to `any_of` if any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer`, or `number`) or contains constraints/metadata. |
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
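This is useful when every leaf value arrives as a string (environment variables, CSV cells, query parameters); a sketch with a hypothetical `Event` model:
```python
from datetime import date

from pydantic import BaseModel

class Event(BaseModel):
    count: int
    when: date

# All values are strings; validation coerces them to the declared types.
e = Event.model_validate_strings({"count": "3", "when": "2025-01-01"})
print(e.count, e.when)  # 3 2025-01-01
```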
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Get extra fields set during validation. Returns a dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | The set of fields that have been explicitly set on this model instance, i.e. that were not filled from defaults. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.prefetch/vllmshardargs ===
# VLLMShardArgs
**Package:** `flyte.prefetch`
Arguments for sharding a model using vLLM.
```python
class VLLMShardArgs(
data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | Field values, passed as keyword arguments. |
## Methods
| Method | Description |
|-|-|
| `construct()` | Deprecated; use `model_construct()` instead. |
| `copy()` | Returns a copy of the model. Deprecated; use `model_copy()` instead. |
| `dict()` | Deprecated; use `model_dump()` instead. |
| `from_orm()` | Deprecated; use `model_validate()` with `from_attributes=True` instead. |
| `get_vllm_args()` | Get the arguments dict for the vLLM `LLM` constructor. |
| `json()` | Deprecated; use `model_dump_json()` instead. |
| `model_construct()` | Creates a new instance of the `Model` class with validated data. |
| `model_copy()` | Returns a copy of the model. |
| `model_dump()` | Generate a dictionary representation of the model. |
| `model_dump_json()` | Generates a JSON representation of the model. |
| `model_json_schema()` | Generates a JSON schema for a model class. |
| `model_parametrized_name()` | Compute the class name for parametrizations of generic classes. |
| `model_post_init()` | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| `model_rebuild()` | Try to rebuild the pydantic-core schema for the model. |
| `model_validate()` | Validate a pydantic model instance. |
| `model_validate_json()` | Validate the given JSON data against the Pydantic model. |
| `model_validate_strings()` | Validate the given object with string data against the Pydantic model. |
| `parse_file()` | Deprecated; load the file yourself, then use `model_validate_json()` or `model_validate()` instead. |
| `parse_obj()` | Deprecated; use `model_validate()` instead. |
| `parse_raw()` | Deprecated; use `model_validate_json()` instead. |
| `schema()` | Deprecated; use `model_json_schema()` instead. |
| `schema_json()` | Deprecated; use `model_json_schema()` with `json.dumps()` instead. |
| `update_forward_refs()` | Deprecated; use `model_rebuild()` instead. |
| `validate()` | Deprecated; use `model_validate()` instead. |
### construct()
```python
def construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
update: Dict[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### get_vllm_args()
```python
def get_vllm_args(
model_path: str,
) -> dict[str, Any]
```
Get the arguments dict for the vLLM `LLM` constructor.
| Parameter | Type | Description |
|-|-|-|
| `model_path` | `str` | |
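A hedged usage sketch; the sharding fields on `VLLMShardArgs` and the exact keys of the returned dict are not listed here, so treat the call below as illustrative:
```python
from flyte.prefetch import VLLMShardArgs

args = VLLMShardArgs()  # set sharding fields as keyword arguments if needed
vllm_kwargs = args.get_vllm_args(model_path="/models/my-model")

# The returned dict is meant to be splatted into vLLM's LLM constructor:
#   from vllm import LLM
#   llm = LLM(**vllm_kwargs)
```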
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class with validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
| `update` | `Mapping[str, Any] \| None` | |
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template. |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, as a subclass of `GenerateJsonSchema` with your desired modifications |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | How unions are encoded in the schema. - `'any_of'`: Use the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default). - `'primitive_type_array'`: Use the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings, containing each type of the combination. If any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer` or `number`) or contains constraints/metadata, falls back to `any_of`. |
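As a quick sketch (using a hypothetical stand-in model), the generated schema is a plain dictionary:
```python
from pydantic import BaseModel


class ShardArgs(BaseModel):  # hypothetical stand-in model
    replicas: int = 1


schema = ShardArgs.model_json_schema(mode="validation")
print(schema["properties"]["replicas"])
# {'default': 1, 'title': 'Replicas', 'type': 'integer'}
```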
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
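A brief sketch of overriding the naming scheme, assuming standard Pydantic generic-model behavior:
```python
from typing import Any, Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class Pair(BaseModel, Generic[T]):
    first: T
    second: T

    @classmethod
    def model_parametrized_name(cls, params: tuple[type[Any], ...]) -> str:
        # params is (int,) for Pair[int]
        return f"Pair_of_{params[0].__name__}"


print(Pair[int].__name__)  # "Pair_of_int"
```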
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
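A minimal sketch of the typical use, assuming a self-referencing annotation that needs an explicit rebuild:
```python
from pydantic import BaseModel


class Node(BaseModel):
    value: int
    next: "Node | None" = None  # forward reference


# Force a rebuild of the schema; returns True on success,
# or None if no rebuild was required
Node.model_rebuild(force=True)
```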
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
!!! abstract "Usage Documentation"
**Flyte SDK > Packages > flyte.prefetch > VLLMShardArgs > JSON Parsing**
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
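For example, a minimal sketch with a hypothetical stand-in model:
```python
from pydantic import BaseModel


class ShardArgs(BaseModel):  # hypothetical stand-in model
    model_name: str
    replicas: int = 1


args = ShardArgs.model_validate_json('{"model_name": "llama-3", "replicas": 2}', strict=True)
print(args.replicas)  # 2
```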
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Get extra fields set during validation. Returns: A dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | Returns the set of fields that have been explicitly set on this model instance. Returns: A set of strings representing the fields that have been set, i.e. that were not filled from defaults. |
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.remote ===
# flyte.remote
Remote Entities that are accessible from the Union Server once deployed or created.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > `Action`** | A class representing an action. |
| **Flyte SDK > Packages > flyte.remote > `ActionDetails`** | A class representing an action. |
| **Flyte SDK > Packages > flyte.remote > `ActionInputs`** | A class representing the inputs of an action. |
| **Flyte SDK > Packages > flyte.remote > `ActionOutputs`** | A class representing the outputs of an action. |
| **Flyte SDK > Packages > flyte.remote > `App`** | A mixin class that provides a method to convert an object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > `Project`** | A class representing a project in the Union API. |
| **Flyte SDK > Packages > flyte.remote > `Run`** | A class representing a run of a task. |
| **Flyte SDK > Packages > flyte.remote > `RunDetails`** | A class representing a run of a task. |
| **Flyte SDK > Packages > flyte.remote > `Secret`** | |
| **Flyte SDK > Packages > flyte.remote > `Task`** | |
| **Flyte SDK > Packages > flyte.remote > `TaskDetails`** | |
| **Flyte SDK > Packages > flyte.remote > `Trigger`** | |
| **Flyte SDK > Packages > flyte.remote > `User`** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > `create_channel()`** | Creates a new gRPC channel with appropriate authentication interceptors. |
| **Flyte SDK > Packages > flyte.remote > `upload_dir()`** | Uploads a directory to a remote location and returns the remote URI. |
| **Flyte SDK > Packages > flyte.remote > `upload_file()`** | Uploads a file to a remote location and returns the remote URI. |
## Methods
#### create_channel()
```python
def create_channel(
endpoint: str | None,
api_key: str | None,
insecure: typing.Optional[bool],
insecure_skip_verify: typing.Optional[bool],
ca_cert_file_path: typing.Optional[str],
ssl_credentials: typing.Optional[ssl_channel_credentials],
grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]],
compression: typing.Optional[grpc.Compression],
http_session: httpx.AsyncClient | None,
proxy_command: typing.Optional[typing.List[str]],
kwargs,
) -> grpc.aio._base_channel.Channel
```
Creates a new gRPC channel with appropriate authentication interceptors.
This function creates either a secure or insecure gRPC channel based on the provided parameters,
and adds authentication interceptors to the channel. If SSL credentials are not provided,
they are created based on the insecure_skip_verify and ca_cert_file_path parameters.
The function is async because it may need to read certificate files asynchronously
and create authentication interceptors that perform async operations.
| Parameter | Type | Description |
|-|-|-|
| `endpoint` | `str \| None` | The endpoint URL for the gRPC channel |
| `api_key` | `str \| None` | API key for authentication; if provided, it will be used to detect the endpoint and credentials. |
| `insecure` | `typing.Optional[bool]` | Whether to use an insecure channel (no SSL) |
| `insecure_skip_verify` | `typing.Optional[bool]` | Whether to skip SSL certificate verification |
| `ca_cert_file_path` | `typing.Optional[str]` | Path to CA certificate file for SSL verification |
| `ssl_credentials` | `typing.Optional[ssl_channel_credentials]` | Pre-configured SSL credentials for the channel |
| `grpc_options` | `typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]]` | Additional gRPC channel options |
| `compression` | `typing.Optional[grpc.Compression]` | Compression method for the channel |
| `http_session` | `httpx.AsyncClient \| None` | Pre-configured HTTP session to use for requests |
| `proxy_command` | `typing.Optional[typing.List[str]]` | List of strings for proxy command configuration |
| `kwargs` | `**kwargs` | Additional arguments passed through to the underlying functions. For grpc.aio.insecure_channel/secure_channel: `root_certificates`, `private_key`, and `certificate_chain` for SSL credentials; `options` (gRPC channel options); `compression` (gRPC compression method). For proxy configuration: `proxy_env` (dict of environment variables for the proxy); `proxy_timeout` (timeout for the proxy connection). For authentication interceptors (passed to create_auth_interceptors and create_proxy_auth_interceptors): `auth_type` (one of "Pkce", "ClientSecret", "ExternalCommand", "DeviceFlow"); `command` (command to execute for ExternalCommand authentication); `client_id` and `client_secret` (for ClientSecret authentication); `client_credentials_secret` (alias for the client secret); `scopes` (list of scopes to request during authentication); `audience` (audience for the token); `http_proxy_url` (HTTP proxy URL); `verify` (whether to verify SSL certificates); `ca_cert_path` (optional path to a CA certificate file); `header_key` (header key to use for authentication); `redirect_uri` (OAuth2 redirect URI for PKCE authentication); `add_request_auth_code_params_to_request_access_token_params` (whether to add auth code params to the token request); `request_auth_code_params` (parameters to add to the login URI opened in the browser); `request_access_token_params` (parameters to add when exchanging the auth code for an access token); `refresh_access_token_params` (parameters to add when refreshing the access token). |
:return: grpc.aio.Channel with authentication interceptors configured.
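As a rough sketch of the common case (the endpoint below is a placeholder, and the remaining parameters are assumed to have usable defaults):
```python
import asyncio

from flyte.remote import create_channel


async def main():
    # Placeholder endpoint; auth interceptors are attached based on the kwargs (defaults here)
    channel = await create_channel(endpoint="dns:///flyte.example.com:443")
    await channel.close()


asyncio.run(main())
```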
#### upload_dir()
```python
def upload_dir(
dir_path: pathlib.Path,
verify: bool,
) -> str
```
Uploads a directory to a remote location and returns the remote URI.
| Parameter | Type | Description |
|-|-|-|
| `dir_path` | `pathlib.Path` | The directory path to upload. |
| `verify` | `bool` | Whether to verify the certificate for HTTPS requests. |
:return: The remote URI of the uploaded directory.
#### upload_file()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await upload_file.aio()`.
```python
def upload_file(
fp: pathlib.Path,
verify: bool,
) -> typing.Tuple[str, str]
```
Uploads a file to a remote location and returns the remote URI.
| Parameter | Type | Description |
|-|-|-|
| `fp` | `pathlib.Path` | The file path to upload. |
| `verify` | `bool` | Whether to verify the certificate for HTTPS requests. |
:return: A tuple containing the MD5 digest and the remote URI.
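A minimal sketch of both calling styles described in the note above (the file path is a placeholder):
```python
import pathlib

from flyte.remote import upload_file

# Synchronous (blocking) call
md5_digest, remote_uri = upload_file(pathlib.Path("model.bin"), verify=True)

# Asynchronous form, from within a coroutine:
# md5_digest, remote_uri = await upload_file.aio(pathlib.Path("model.bin"), verify=True)
```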
## Subpages
- [Action](Action/)
- [ActionDetails](ActionDetails/)
- [ActionInputs](ActionInputs/)
- [ActionOutputs](ActionOutputs/)
- [App](App/)
- [Project](Project/)
- [Run](Run/)
- [RunDetails](RunDetails/)
- [Secret](Secret/)
- [Task](Task/)
- [TaskDetails](TaskDetails/)
- [Trigger](Trigger/)
- [User](User/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.report ===
# flyte.report
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.report > `Report`** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.report > `current_report()`** | Get the current report. |
| **Flyte SDK > Packages > flyte.report > `flush()`** | Flush the report. |
| **Flyte SDK > Packages > flyte.report > `get_tab()`** | Get a tab by name. |
| **Flyte SDK > Packages > flyte.report > `log()`** | Log content to the main tab. |
| **Flyte SDK > Packages > flyte.report > `replace()`** | Replace the content of the main tab. |
## Methods
#### current_report()
```python
def current_report()
```
Get the current report. This is a dummy report if not in a task context.
:return: The current report.
#### flush()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await flush.aio()`.
```python
def flush()
```
Flush the report.
#### get_tab()
```python
def get_tab(
name: str,
create_if_missing: bool,
) -> flyte.report._report.Tab
```
Get a tab by name. If the tab does not exist, create it.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the tab. |
| `create_if_missing` | `bool` | Whether to create the tab if it does not exist. |
:return: The tab.
#### log()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await log.aio()`.
```python
def log(
content: str,
do_flush: bool,
)
```
Log content to the main tab. The content should be a valid HTML string, but not a complete HTML document,
as it will be inserted into a div.
| Parameter | Type | Description |
|-|-|-|
| `content` | `str` | The content to log. |
| `do_flush` | `bool` | Whether to flush the report after logging. |
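For example, a minimal sketch from inside a running task:
```python
import flyte.report as report

# Append an HTML fragment to the main tab and flush it immediately
report.log("<h3>Training started</h3>", do_flush=True)

# Non-blocking form, from within an async task:
# await report.log.aio("<h3>Training started</h3>", do_flush=True)
```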
#### replace()
> [!NOTE] This method can be called both synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.,:
> `result = await replace.aio()`.
```python
def replace(
content: str,
do_flush: bool,
)
```
Replaces the content of the main tab.
| Parameter | Type | Description |
|-|-|-|
| `content` | `str` | The new content for the main tab. |
| `do_flush` | `bool` | Whether to flush the report after replacing the content. |
## Subpages
- [Report](Report/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.storage ===
# flyte.storage
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > `ABFS`** | Any Azure Blob Storage specific configuration. |
| **Flyte SDK > Packages > flyte.storage > `GCS`** | Any GCS specific configuration. |
| **Flyte SDK > Packages > flyte.storage > `S3`** | S3 specific configuration. |
| **Flyte SDK > Packages > flyte.storage > `Storage`** | Data storage configuration that applies across any provider. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > `exists()`** | Check if a path exists. |
| **Flyte SDK > Packages > flyte.storage > `exists_sync()`** | |
| **Flyte SDK > Packages > flyte.storage > `get()`** | |
| **Flyte SDK > Packages > flyte.storage > `get_configured_fsspec_kwargs()`** | |
| **Flyte SDK > Packages > flyte.storage > `get_random_local_directory()`** | :return: a random directory. |
| **Flyte SDK > Packages > flyte.storage > `get_random_local_path()`** | Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name. |
| **Flyte SDK > Packages > flyte.storage > `get_stream()`** | Get a stream of data from a remote location. |
| **Flyte SDK > Packages > flyte.storage > `get_underlying_filesystem()`** | |
| **Flyte SDK > Packages > flyte.storage > `is_remote()`** | Check whether a path refers to a remote location. |
| **Flyte SDK > Packages > flyte.storage > `join()`** | Join multiple paths together. |
| **Flyte SDK > Packages > flyte.storage > `open()`** | Asynchronously open a file and return an async context manager. |
| **Flyte SDK > Packages > flyte.storage > `put()`** | |
| **Flyte SDK > Packages > flyte.storage > `put_stream()`** | Put a stream of data to a remote location. |
## Methods
#### exists()
```python
def exists(
path: str,
kwargs,
) -> bool
```
Check if a path exists.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | Path to be checked. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
:return: True if the path exists, False otherwise.
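A minimal sketch (the path is a placeholder; given the separate `exists_sync()` below, `exists()` is assumed here to be awaitable from async code):
```python
import flyte.storage as storage


async def check(path: str) -> bool:
    # Extra kwargs (if any) are forwarded to the underlying filesystem
    return await storage.exists(path)

# From synchronous code, the blocking variant can be used instead:
# storage.exists_sync("s3://my_bucket/my_file.txt")
```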
#### exists_sync()
```python
def exists_sync(
path: str,
kwargs,
) -> bool
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `kwargs` | `**kwargs` | |
#### get()
```python
def get(
from_path: str,
to_path: Optional[str | pathlib.Path],
recursive: bool,
kwargs,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `from_path` | `str` | |
| `to_path` | `Optional[str \| pathlib.Path]` | |
| `recursive` | `bool` | |
| `kwargs` | `**kwargs` | |
#### get_configured_fsspec_kwargs()
```python
def get_configured_fsspec_kwargs(
protocol: typing.Optional[str],
anonymous: bool,
) -> typing.Dict[str, typing.Any]
```
| Parameter | Type | Description |
|-|-|-|
| `protocol` | `typing.Optional[str]` | |
| `anonymous` | `bool` | |
#### get_random_local_directory()
```python
def get_random_local_directory()
```
:return: a random directory
:rtype: pathlib.Path
#### get_random_local_path()
```python
def get_random_local_path(
file_path_or_file_name: pathlib.Path | str | None,
) -> pathlib.Path
```
Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name
| Parameter | Type | Description |
|-|-|-|
| `file_path_or_file_name` | `pathlib.Path \| str \| None` | |
#### get_stream()
```python
def get_stream(
path: str,
chunk_size,
kwargs,
) -> AsyncGenerator[bytes, None]
```
Get a stream of data from a remote location.
This is useful for downloading streaming data from a remote location.
Example usage:
```python
import flyte.storage as storage
async for chunk in storage.get_stream(path="s3://my_bucket/my_file.txt"):
process(chunk)
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | Path to the remote location where the data will be downloaded. |
| `chunk_size` | `int` | Size of each chunk to be read from the file. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
:return: An async iterator that yields chunks of bytes.
#### get_underlying_filesystem()
```python
def get_underlying_filesystem(
protocol: typing.Optional[str],
anonymous: bool,
path: typing.Optional[str],
kwargs,
) -> fsspec.AbstractFileSystem
```
| Parameter | Type | Description |
|-|-|-|
| `protocol` | `typing.Optional[str]` | |
| `anonymous` | `bool` | |
| `path` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
#### is_remote()
```python
def is_remote(
path: typing.Union[pathlib.Path | str],
) -> bool
```
Check whether the given path refers to a remote location.
| Parameter | Type | Description |
|-|-|-|
| `path` | `typing.Union[pathlib.Path \| str]` | |
#### join()
```python
def join(
paths: str,
) -> str
```
Join multiple paths together. This is a wrapper around os.path.join.
| Parameter | Type | Description |
|-|-|-|
| `paths` | `str` | Paths to be joined. |
#### open()
```python
def open(
path: str,
mode: str,
kwargs,
) -> AsyncReadableFile | AsyncWritableFile
```
Asynchronously open a file and return an async context manager.
This function checks if the underlying filesystem supports obstore bypass.
If it does, it uses obstore to open the file. Otherwise, it falls back to
the standard _open function which uses AsyncFileSystem.
It will raise NotImplementedError if neither obstore nor AsyncFileSystem is supported.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `mode` | `str` | |
| `kwargs` | `**kwargs` | |
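A minimal sketch, assuming `open` is awaited from async code and that the returned object exposes an async `read()`; both are assumptions based on the description above rather than a confirmed API:
```python
import flyte.storage as storage


async def read_remote(path: str) -> bytes:
    # Assumed usage: await to obtain the async file object, then read from it
    f = await storage.open(path, mode="rb")
    return await f.read()
```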
#### put()
```python
def put(
from_path: str,
to_path: Optional[str],
recursive: bool,
kwargs,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `from_path` | `str` | |
| `to_path` | `Optional[str]` | |
| `recursive` | `bool` | |
| `kwargs` | `**kwargs` | |
#### put_stream()
```python
def put_stream(
data_iterable: typing.AsyncIterable[bytes] | bytes,
name: str | None,
to_path: str | None,
kwargs,
) -> str
```
Put a stream of data to a remote location. This is useful for streaming data to a remote location.
Example usage:
```python
import flyte.storage as storage
storage.put_stream(iter([b'hello']), name="my_file.txt")
# or
storage.put_stream(iter([b'hello']), to_path="s3://my_bucket/my_file.txt")
```
| Parameter | Type | Description |
|-|-|-|
| `data_iterable` | `typing.AsyncIterable[bytes] \| bytes` | Iterable of bytes to be streamed. |
| `name` | `str \| None` | Name of the file to be created. If not provided, a random name will be generated. |
| `to_path` | `str \| None` | Path to the remote location where the data will be stored. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
:return: The path to the remote location where the data was stored.
:rtype: str
## Subpages
- [ABFS](ABFS/)
- [GCS](GCS/)
- [S3](S3/)
- [Storage](Storage/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.syncify ===
# flyte.syncify
# Syncify Module
This module provides the `syncify` decorator and the `Syncify` class.
The decorator can be used to convert asynchronous functions or methods into synchronous ones.
This is useful for integrating async code into synchronous contexts.
Every asynchronous function or method wrapped with `syncify` can be called synchronously using the
parenthesis `()` operator, or asynchronously using the `.aio()` method.
Example:
```python
from flyte.syncify import syncify
@syncify
async def async_function(x: str) -> str:
return f"Hello, Async World {x}!"
# now you can call it synchronously
result = async_function("Async World") # Note: no .aio() needed for sync calls
print(result)
# Output: Hello, Async World Async World!
# or call it asynchronously
async def main():
result = await async_function.aio("World") # Note the use of .aio() for async calls
print(result)
```
## Creating a Syncify Instance
```python
from flyte.syncify import Syncify
syncer = Syncify("my_syncer")
# Now you can use `syncer` to decorate your async functions or methods
```
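For example, a minimal sketch, assuming the instance is applied as a decorator just like the module-level `syncify`:
```python
from flyte.syncify import Syncify

syncer = Syncify("my_syncer")


@syncer
async def double(x: int) -> int:
    return x * 2


print(double(21))  # synchronous call, prints 42

# Or, from within a coroutine:
# result = await double.aio(21)
```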
## How does it work?
The Syncify class wraps asynchronous functions, classmethods, instance methods, and static methods to
provide a synchronous interface. The wrapped methods are always executed in the context of a background loop,
whether they are called synchronously or asynchronously. This allows for seamless integration of async code,
since certain async libraries, such as grpc.aio, capture the event loop. In such cases, the Syncify class
ensures that the async function is executed in the context of the background loop.
To use it correctly with grpc.aio, you should wrap every grpc.aio channel creation and client invocation
with the same `Syncify` instance, ensuring that the async code runs in the correct event loop context.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.syncify > `Syncify`** | A decorator to convert asynchronous functions or methods into synchronous ones. |
## Subpages
- [Syncify](Syncify/)
=== PAGE: https://www.union.ai/docs/v2/flyte/api-reference/flyte-sdk/packages/flyte.types ===
# flyte.types
# Flyte Type System
The Flyte type system provides a way to define, transform, and manipulate types in Flyte workflows.
Since the data flowing through Flyte often has to cross process, container, and language boundaries, the type system
is designed to be serializable to a universal format that can be understood across different environments. This
universal format is based on Protocol Buffers. The types are called LiteralTypes and the runtime
representation of data is called Literals.
The type system includes:
- **TypeEngine**: The core engine that manages type transformations and serialization. This is the main entry point
for all the internal type transformations and serialization logic.
- **TypeTransformer**: A class that defines how to transform one type to another. This is extensible,
allowing users to define custom types and transformations.
- **Renderable**: An interface for types that can be rendered as HTML and output to a flyte.report.
It is always possible to bypass the type system and use the `FlytePickle` type to serialize any Python object
into a pickle format. The pickle format is not human-readable, but it can be passed between Flyte tasks that are
written in Python. Pickled objects cannot be represented in the UI and may be inefficient for large datasets.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > `FlytePickle`** | This type is only used by flytekit internally. |
| **Flyte SDK > Packages > flyte.types > `TypeEngine`** | Core Extensible TypeEngine of Flytekit. |
| **Flyte SDK > Packages > flyte.types > `TypeTransformer`** | Base transformer type that should be implemented for every python native type that can be handled by flytekit. |
### Protocols
| Protocol | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > `Renderable`** | Base class for protocol classes. |
### Errors
| Exception | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > `TypeTransformerFailedError`** | Inappropriate argument type. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.types > `guess_interface()`** | Returns the interface of the task with guessed types, as types may not be present in current env. |
| **Flyte SDK > Packages > flyte.types > `literal_string_repr()`** | This method is used to convert a literal map to a string representation. |
## Methods
#### guess_interface()
```python
def guess_interface(
interface: flyteidl2.core.interface_pb2.TypedInterface,
default_inputs: typing.Optional[typing.Iterable[flyteidl2.task.common_pb2.NamedParameter]],
) -> flyte.models.NativeInterface
```
Returns the interface of the task with guessed types, as types may not be present in current env.
| Parameter | Type | Description |
|-|-|-|
| `interface` | `flyteidl2.core.interface_pb2.TypedInterface` | |
| `default_inputs` | `typing.Optional[typing.Iterable[flyteidl2.task.common_pb2.NamedParameter]]` | |
#### literal_string_repr()
```python
def literal_string_repr(
lm: typing.Union[flyteidl2.core.literals_pb2.Literal, flyteidl2.task.common_pb2.NamedLiteral, flyteidl2.task.common_pb2.Inputs, flyteidl2.task.common_pb2.Outputs, flyteidl2.core.literals_pb2.LiteralMap, typing.Dict[str, flyteidl2.core.literals_pb2.Literal]],
) -> typing.Dict[str, typing.Any]
```
This method is used to convert a literal map to a string representation.
| Parameter | Type | Description |
|-|-|-|
| `lm` | `typing.Union[flyteidl2.core.literals_pb2.Literal, flyteidl2.task.common_pb2.NamedLiteral, flyteidl2.task.common_pb2.Inputs, flyteidl2.task.common_pb2.Outputs, flyteidl2.core.literals_pb2.LiteralMap, typing.Dict[str, flyteidl2.core.literals_pb2.Literal]]` | |
## Subpages
- [FlytePickle](FlytePickle/)
- [Renderable](Renderable/)
- [TypeEngine](TypeEngine/)
- [TypeTransformer](TypeTransformer/)
- [TypeTransformerFailedError](TypeTransformerFailedError/)
=== PAGE: https://www.union.ai/docs/v2/flyte/community ===
# Community
Flyte is an open source project that is built and maintained by a community of contributors.
Union AI is the primary maintainer of Flyte and developer of Union.ai, a closed source commercial product that is built on top of Flyte.
Since the success of Flyte is essential to the success of Union.ai,
the company is dedicated to building and expanding the Flyte open source project and community.
For information on how to get involved and how to keep in touch, see **Joining the community**.
## Contributing to the codebase
The full Flyte codebase is open source and available on GitHub.
If you are interested, see **Contributing code**.
## Contributing to documentation
Union AI maintains and hosts both Flyte and Union documentation at [www.union.ai/docs](/docs/v2/root/).
The two sets of documentation are deeply integrated, as the Union product is built on top of Flyte.
To better maintain both sets of docs, they are hosted in the same repository (but rendered so that you can choose to view one or the other).
Both the Flyte and Union documentation are open source.
Flyte community members and Union customers are both welcome to contribute to the documentation.
If you are interested, see [Contributing documentation and examples](./contributing-docs/_index).
## Subpages
- **Joining the community**
- **Contributing code**
- **Contributing docs and examples**
=== PAGE: https://www.union.ai/docs/v2/flyte/community/joining-the-community ===
# Joining the community
Keeping the lines of communication open is important in growing and maintaining the Flyte community.
Please join us on:
[](https://slack.flyte.org)
[](https://github.com/flyteorg/flyte/discussions)
[](https://twitter.com/flyteorg)
[](https://www.linkedin.com/groups/13962256)
## Community sync
1. **When**: First Tuesday of every month, 9:00 AM Pacific Time.
2. **Where**: Live streamed on [YouTube](https://www.youtube.com/@flyteorg/streams) and [LinkedIn](https://www.linkedin.com/company/union-ai/events/).
3. **Watch the recordings**: [here](https://www.youtube.com/live/d81Jd4rfmzw?feature=shared).
4. **Import the public calendar**: [here](https://lists.lfaidata.foundation/g/flyte-announce/ics/12031983/2145304139/feed.ics) to not miss any event.
5. **Want to present?** Fill out [this form](https://tally.so/r/wgN8LM). We're eager to learn from you!
You're welcome to join and learn from other community members sharing their experiences with Flyte or any other technology from the AI ecosystem.
## Contributor's sync
1. **When**: Every 2 weeks on Thursdays. Alternating schedule between 11:00 AM PT and 7:00 AM PT.
2. **Where**: Live on [Zoom](https://zoom-lfx.platform.linuxfoundation.org/meeting/92309721545?password=c93d76a7-801a-47c6-9916-08e38e5a5c1f).
3. **Purpose**: Address questions from new contributors and discuss active initiatives and RFCs.
4. **Import the public calendar**: [here](https://lists.lfaidata.foundation/g/flyte-announce/ics/12031983/2145304139/feed.ics) to not miss any event.
## Newsletter
[Join the Flyte mailing list](https://lists.lfaidata.foundation/g/flyte-announce/join) to receive the monthly newsletter.
## Slack guidelines
Flyte strives to build and maintain an open, inclusive, productive, and self-governing open source community.
In consequence, we expect all community members to respect the following guidelines:
### Abide by the [LF's Code of Conduct](https://lfprojects.org/policies/code-of-conduct/)
As a Linux Foundation project, we must enforce the rules that govern professional and positive open source communities.
### Avoid using DMs and @mentions
Whenever possible, post your questions and responses in public channels so other community members can benefit from the conversation and outcomes.
Exceptions to this are when you need to share private or sensitive information.
In such a case, the outcome should still be shared publicly.
Limit the use of `@mentions` of other community members to be considerate of notification noise.
### Make use of threads
Threads help us keep conversations contained and organized, reducing the time it takes to give you the support you need.
**Thread best practices:**
- Don't break your question into multiple messages. Put everything in one.
- For long questions, write a few sentences in the first message, and put the rest in a thread.
- If there's a code snippet (more than 5 lines of code), put it inside the thread.
- Avoid using the βAlso send to channelβ feature unless it's really necessary.
- If your question contains multiple questions, make sure to break them into multiple messages, so each could be answered in a separate thread.
### Do not post the same question across multiple channels
If you consider that a question needs to be shared on other channels, ask it once and then indicate explicitly that you're cross-posting.
If you're having a tough time getting the support you need (or aren't sure where to go!), please DM `@David Espejo` or `@Samhita Alla` for support.
### Do not solicit members of our Slack
The Flyte community exists to collaborate with, learn from, and support one another.
It is not a space to pitch your products or services directly to our members via public channels, private channels, or direct messages.
We are excited to have a growing presence from vendors to help answer questions from community members as they may arise, but we have a strict 3-strike policy against solicitation:
- **First occurrence**: We'll give you a friendly but public reminder that the behavior is inappropriate according to our guidelines.
- **Second occurrence**: We'll send you a DM warning that any additional violations will result in removal from the community.
- **Third occurrence**: We'll delete or ban your account.
We reserve the right to ban users without notice if they are clearly spamming our community members.
If you want to promote a product or service, go to the `#shameless-promotion` channel and make sure to follow these rules:
- Don't post more than two promotional posts per week.
- Non-relevant topics aren't allowed.
Messages that don't follow these rules will be deleted.
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-code ===
# Contributing code
Thank you for your interest in Flyte!
> [!NOTE]
> This page is part of the Flyte 2 documentation.
> If you are interested in contributing code for Flyte 1, switch the selector at the top of the page to **v1**.
## Flyte 2
Flyte 2 is currently in active development.
The Flyte 2 SDK source code is available on [GitHub](https://github.com/flyteorg/flyte-sdk) under the same Apache license as the original Flyte 1.
You are welcome to take a look, [download the package](https://pypi.org/project/flyte/#history) and try running code locally.
Keep in mind that this is still in beta and is a work in progress.
The Flyte 2 backend is not yet available as open source (but it will be soon!).
To run Flyte 2 code now, you can apply for a [beta preview of the Union 2 backend](https://www.union.ai/beta).
When the Flyte 2 backend is released we will roll out a full contributor program just as we have for Flyte 1.
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs ===
# Contributing docs and examples
We welcome contributions to the docs and examples for both Flyte and Union.
This section will explain how the docs site works, how to author and build it locally, and how to publish your changes.
## The combined Flyte and Union docs site
As the primary maintainer and contributor of the open-source Flyte project, Union AI is responsible for hosting the Flyte documentation.
Additionally, Union AI is also the company behind the commercial Union.ai product, which is based on Flyte.
Since Flyte and Union.ai share a lot of common functionality, much of the documentation content is common between the two.
However, there are some significant differences not only between Flyte and Union.ai but also among the different Union.ai product offerings (Serverless, BYOC, and Self-managed).
To effectively and efficiently maintain the documentation for all of these variants, we employ a single-source-of-truth approach where:
* All content is stored in a single GitHub repository, [`unionai/unionai-docs`](https://github.com/unionai/unionai-docs)
* All content is published on a single website, [`www.union.ai/docs`](/docs/v2/root/).
* The website has a variant selector at the top of the page that lets you choose which variant you want to view:
* Flyte OSS
* Union Serverless
* Union BYOC
* Union Self-managed
* There is also a version selector. Currently, two versions are available:
* v1 (the original docs for Flyte/Union 1.x)
* v2 (the new docs for Flyte/Union 2.0, which is the one you are currently viewing)
## Versions
The two versions of the docs are stored in separate branches of the GitHub repository:
* [`v1` branch](https://github.com/unionai/unionai-docs/tree/v1) for the v1 docs.
* [`main` branch](https://github.com/unionai/unionai-docs) for the v2 docs.
See **Contributing docs and examples > Versions** for more details.
## Variants
Within each branch the multiple variants are supported by using conditional rendering:
* Each page of content has a `variants` front matter field that specifies which variants the page is applicable to.
* Within each page, rendering logic can be used to include or exclude content based on the selected variant.
The result is that:
* Content that is common to all variants is authored and stored once.
There is no need to keep multiple copies of the same content in-sync.
* Content specific to a variant is conditionally rendered based on the selected variant.
See **Contributing docs and examples > Variants** for more details.
## Both Flyte and Union docs are open source
Since the docs are now combined in one repository, and the Flyte docs are open source, the Union docs are also open source.
All the docs are available for anyone to contribute to: Flyte contributors, Union customers, and Union employees.
If you are a Flyte contributor, you will be contributing docs related to Flyte features and functionality, but in many cases these features and functionality will also be available in Union.
Because the docs site is a single source for all the documentation, when you make changes related to Flyte that are also valid for Union you do so in the same place.
This is by design and is a key feature of the docs site.
## Subpages
- **Contributing docs and examples > Quick start**
- **Contributing docs and examples > Variants**
- **Contributing docs and examples > Versions**
- **Contributing docs and examples > Authoring**
- **Contributing docs and examples > Shortcodes**
- **Contributing docs and examples > Redirects**
- **Contributing docs and examples > API docs**
- **Contributing docs and examples > Publishing**
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/quick-start ===
# Quick start
## Prerequisites
The docs site is built using the [Hugo](https://gohugo.io/) static site generator.
You will need to install it to build the site locally.
See [Hugo Installation](https://gohugo.io/getting-started/installing/).
## Clone the repository
Clone the [`unionai/docs`](https://github.com/unionai/unionai-docs) repository to your local machine.
The content is located in the `content/` folder in the form of Markdown files.
The hierarchy of the files and folders under `content/` directly reflects the URL and navigation structure of the site.
## Live preview
Next, set up the live preview by going to the root of your local repository checkout and copying `hugo.local.toml~sample` to `hugo.local.toml`:
```shell
$ cp hugo.local.toml~sample hugo.local.toml
```
This file contains the configuration for the live preview:
By default, it is set to display the `flyte` variant of the docs site and enables the flags `show_inactive`, `highlight_active`, and `highlight_keys` (more about these below).
Now you can start the live preview server by running:
```shell
$ make dev
```
This will build the site and launch a local server at `http://localhost:1313`.
Go to that URL to see the live preview. Leave the server running.
As you edit the content you will see the changes reflected in the live preview.
## Distribution build
To build the site for distribution, run:
```shell
$ make dist
```
This will build the site locally just as it is built by the Cloudflare CI for production.
You can view the result of the build by running a local server:
```shell
$ make serve
```
This will start a local server at `http://localhost:9000` and serve the contents of the `dist/` folder. You can also specify a port number:
```shell
$ make serve PORT=<port>
```
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/variants ===
# Variants
The docs site supports the ability to show or hide content based on the current variant selection.
There are separate mechanisms for:
* Including or excluding entire pages based on the selected variant.
* Conditional rendering of content within a page based on the selected variant using an if-then-like construct.
* Rendering keywords as variables that change based on the selected variant.
Currently, the docs site supports four variants:
- **Flyte OSS**: The open-source Flyte project.
- **Serverless**: The Union.ai product that is hosted and managed by Union AI.
- **BYOC**: The Union.ai product that is hosted on the customer's infrastructure but managed by Union AI.
- **Self-managed**: The Union.ai product that is hosted and managed by the customer.
Each variant is referenced in the page logic using its respective code name: `flyte`, `serverless`, `byoc`, or `selfmanaged`.
The available set of variants is defined in the `config.<variant>.toml` files in the root of the repository.
## Variants at the whole-page level
The docs site supports the ability to show or hide entire pages based on the selected variant.
Not all pages are available in all variants because features differ across the variants.
In the public website, if you are on a page in one variant and you change to a different variant, the page will change to the same page in the new variant *if it exists*.
If it does not exist, you will see a message indicating that the page is not available in the selected variant.
In the source Markdown, the presence or absence of a page in a given variant is governed by the `variants` field in the front matter of the page.
For example, if you look at the Markdown source for [this page (the page you are currently viewing)](https://github.com/unionai/docs/content/community/contributing-docs.md), you will see the following front matter:
```markdown
---
title: Platform overview
weight: 1
variants: +flyte +serverless +byoc +selfmanaged
---
```
The `variants` field has the value:
`+flyte +serverless +byoc +selfmanaged`
The `+` indicates that the page is available for the specified variant.
In this case, the page is available for all four variants.
If you wanted to make the page available for only the `flyte` and `serverless` variants, you would change the `variants` field to:
`+flyte +serverless -byoc -selfmanaged`
In [live preview mode](./authoring-core-content#live-preview) with the `show_inactive` flag enabled, you will see all pages in the navigation tree, with the ones unavailable for the current variant grayed out.
As you can see, the `variants` field expects a space-separated list of keywords:
* The code names for the currently supported variants are `flyte`, `serverless`, `byoc`, and `selfmanaged`.
* All supported variants must be included explicitly in every `variants` field with a leading `+` or `-`. There is no default behavior.
* The supported variants are configured in the root of the repository in the files named `config.<variant>.toml`.
## Conditional rendering within a page
Content can also differ *within a page* based on the selected variant.
This is done with conditional rendering using the `{{* variant */>}}` and `{{* key */>}}` [Hugo shortcodes](https://gohugo.io/content-management/shortcodes/).
### {{* variant */>}}
The syntax for the `{{* variant */>}}` shortcode is:
```markdown
{{* variant <variants> */>}}
...
{{* /variant */>}}
```
Where `<variants>` is a space-separated list of code names for the variants you want to show the content for.
Note that the variant construct can only directly contain other shortcode constructs, not plain Markdown.
In the most common case, you will want to use the `{{* markdown */>}}` shortcode (which can contain Markdown) inside the `{{* variant */>}}` shortcode to render Markdown content, like this:
```markdown
{{* variant serverless byoc */>}}
{{* markdown */>}}
This content is only visible in the `serverless` and `byoc` variants.
{{* /markdown */>}}
{{* button-link text="Contact Us" target="https://union.ai/contact" */>}}
{{* /variant */>}}
```
For more details on the `{{* variant */>}}` shortcode, see the **Contributing docs and examples > Shortcodes > Component Library > `{{* variant */>}}`**.
### {{* key */>}}
The syntax for the `{{* key */>}}` shortcode is:
```markdown
{{* key <key-name> */>}}
```
Where `<key-name>` is the name of the key you want to render.
For example, if you want to render the product name keyword, you would use:
```markdown
{{* key product_name */>}}
```
The available key names are defined in the `[params.key]` section of the `hugo.site.toml` configuration file in the root of the repository.
For example, the `product_name` key used above is defined in that file as:
```toml
[params.key.product_name]
flyte = "Flyte"
serverless = "Union.ai"
byoc = "Union.ai"
selfmanaged = "Union.ai"
```
Meaning that in any content that appears in the `flyte` variant of the site, the `{{* key product_name */>}}` shortcode will be replaced with `Flyte`, and in any content that appears in the `serverless`, `byoc`, or `selfmanaged` variants, it will be replaced with `Union.ai`.
For more details on the `{{* key */>}}` shortcode, see the **Contributing docs and examples > Shortcodes > Component Library > `{{* key */>}}`**
## Full example
Here is a full example. If you look at the Markdown source for [this page (the page you are currently viewing)](https://github.com/unionai/docs/content/community/contributing-docs/variants.md), you will see the following section:
```markdown
> **This text is visible in all variants.**
>
> {{* variant flyte */>}}
> {{* markdown */>}}
>
> **This text is only visible in the `flyte` variant.**
>
> {{* /markdown */>}}
> {{* /variant */>}}
> {{* variant serverless byoc selfmanaged */>}}
> {{* markdown */>}}
>
> **This text is only visible in the `serverless`, `byoc`, and `selfmanaged` variants.**
>
> {{* /markdown */>}}
> {{* /variant */>}}
>
> **Below is a `{{* key product_full_name */>}}` shortcode.
> It will be replaced with the current variant's full name:**
>
> **{{* key product_full_name */>}}**
```
This Markdown source is rendered as:
> **This text is visible in all variants.**
>
> **This text is only visible in the `flyte` variant.**
>
> **Below is a `{{* key product_full_name */>}}` shortcode.
> It will be replaced with the current variant's full name:**
>
> **Flyte OSS**
If you switch between variants with the variant selector at the top of the page, you will see the content change accordingly.
## Adding a new variant
A variant is a term we use to identify a product or major section of the site.
Each variant has a dedicated token that identifies it, and all resources are
tagged to be either included or excluded when the variant is built.
> Adding new variants is a rare event and should be reserved for new products
> or major developments.
>
> If you are thinking that adding a new variant is the way
> to go, please double-check with the infra admin to confirm before doing all
> the work below and wasting your time.
### Location
When deploying, the variant takes a folder in the root
`https://<site>/<variant>/`
For example, if we have a variant `acme`, then when built the content goes to:
`https://<site>/acme/`
### Creating a new variant
To create a new variant a few steps are required:
| File | Changes |
| ----------------------- | -------------------------------------------------------------- |
| `hugo.site.toml` | Add to `params.variant_weights` and all `params.key` |
| `hugo.toml` | Add to `params.search` |
| `Makefile` | Add a new `make variant` to `dist` target |
| `*.md` | Add either `+<variant>` or `-<variant>` to all content pages |
| `config.<variant>.toml` | Create a new file and configure `baseURL` and `params.variant` |
### Testing the new variant
As you develop the new variant, it is recommended to have a `pre-release/<variant>` semi-stable
branch to confirm everything is working and the content looks good. It will also allow others
to collaborate by creating PRs against it (`base=pre-release/<variant>` instead of `main`)
without trampling on each other and allowing for parallel reviews.
Once the variant branch is correct, you merge that branch into main.
### Building (just) the variant
You can build the production version of the variant,
which will also trigger all the safety checks,
by invoking the variant build:
```shell
$ make variant VARIANT=<variant>
```
For example:
```shell
make variant VARIANT=serverless
```
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/versions ===
# Versions
In addition to the product variants, the docs site also supports multiple versions of the documentation.
The version selector is located at the top of the page, next to the variant selector.
Versions and variants are independent of each other, with the version being "above" the variant in the URL hierarchy.
The URL for version `v2` of the current page (the one you are on right now) in the Flyte variant is:
`/docs/v2/flyte/community/contributing-docs/versions`
while the URL for version `v1` of the same page is:
`/docs/v1/flyte/community/contributing-docs/versions`
### Versions are branches
The versioning system is based on long-lived Git branches in the `unionai/unionai-docs` GitHub repository:
- The `main` branch contains the latest version of the documentation. Currently, `v2`.
- Other versions of the docs are contained in branches named `vX`, where `X` is the major version number. Currently, there is one other version, `v1`.
## Archive versions
An "archive version" is a static snapshot of the site at a given point in time.
It is meant to freeze a specific version of the site for historical purposes,
preserving its content and structure at that point in time.
### How to create an archive version
1. Create a new branch from `main` named `vX`, e.g. `v3`.
2. Add the version to the `VERSION` field in the `makefile.inc` file, e.g. `VERSION := v3`.
3. Add the version to the `versions` field in the `hugo.ver.toml` file, e.g. `versions = [ "v1", "v2", "v3" ]`.
> [!NOTE]
> **Important:** You must update the `versions` field in **ALL** published and archived versions of the site.
### Publishing an archive version
> [!NOTE]
> This step can only be done by a Union employee.
1. Update the `docs_archive_versions` in the `docs_archive_locals.tf` Terraform file
2. Create a PR for the changes
3. Once the PR is merged, run the production pipeline to activate the new version
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/authoring ===
# Authoring
## Getting started
Content is located in the `content` folder.
To create a new page, simply create a new Markdown file in the appropriate folder and start writing it!
## Target the right branch
Remember that there are two production branches in the docs: `main` and `v1`.
* **For Flyte or Union 1, create a branch off of `v1` and target your pull request to `v1`**
* **For Flyte or Union 2, create a branch off of `main` and target your pull request to `main`**
## Live preview
While editing, you can use Hugo's local live preview capabilities.
Simply execute
```shell
$ make dev
```
This will build the site and launch a local server at `http://localhost:1313`.
Go to that URL to see the live preview. Leave the server running.
As you edit the preview will update automatically.
See **Contributing docs and examples > Publishing** for how to set up your machine.
## Pull Requests + Site Preview
Pull requests will create a preview build of the site on CloudFlare.
Check the pull request for a dynamic link to the site changes within that PR.
## Page Visibility
This site uses variants, which means different "flavors" of the content.
For a given page, its variant visibility is governed by the `variants:` field in the front matter of the page source.
For each variant you specify `+` to include or `-` to exclude it.
For example:
```markdown
---
title: My Page
variants: -flyte +serverless +byoc -selfmanaged
---
```
In this example the page will be:
* Included in Serverless and BYOC.
* Excluded from Flyte and Self-managed.
> [!NOTE]
> All variants must be explicitly listed in the `variants` field.
> This helps avoid missing or extraneous pages.
## Page order
Pages are ordered by the value of the `weight` field (an integer >= 0) in the front matter of the page:
1. The higher the weight, the lower the page sits in the navigation ordering among its peers in the same folder.
2. Pages with no weight field (or `weight = 0`) will be ordered last.
3. Pages of the same weight will be sorted alphabetically by their title.
4. Folders are ordered among their peers (other folders and pages at the same level of the hierarchy) by the weight of their `_index.md` page.
For example:
```markdown
---
title: My Page
weight: 3
---
```
## Page settings
| Setting | Type | Description |
| ------------------ | ---- | --------------------------------------------------------------------------------- |
| `top_menu` | bool | If `true` the item becomes a tab at the top and its hierarchy goes to the sidebar |
| `sidebar_expanded` | bool | If `true`, the section is permanently expanded in the sidebar. |
| `site_root` | bool | If `true` indicates that the page is the site landing page |
| `toc_max` | int | Maximum heading level to include in the right-hand table of contents. |
## Conditional Content
The site has "flavors" of the documentation. We leverage the `{{* variant */>}}` tag to control
which content is rendered on which flavor.
Refer to **Contributing docs and examples > Shortcodes > Variants** for detailed explanation.
## Warnings and Notices
You can write regular Markdown and use the notation below to create information and warning boxes:
```markdown
> [!NOTE] This is the note title
> You write the note content here. It can be
> anything you want.
```
Or if you want a warning:
```markdown
> [!WARNING] This is the title of the warning
> And here you write what you want to warn about.
```
## Special Content Generation
There are various shortcodes to generate content or special components (tabs, dropdowns, etc.).
Refer to **Contributing docs and examples > Shortcodes** for more information.
## Python Generated Content
You can generate pages from markdown-commented Python files.
At the top of your `.md` file, add:
```markdown
---
layout: py_example
example_file: /path/to/your/file.py
run_command: union run --remote tutorials//path/to/your/file.py main
source_location: https://www.github.com/unionai/unionai-examples/tree/main/tutorials/path/to/your/file.py
---
```
Where the referenced file looks like this:
```python
# # Credit Default Prediction with XGBoost & NVIDIA RAPIDS
#
# In this tutorial, we will use NVIDIA RAPIDS `cudf` DataFrame library for preprocessing
# data and XGBoost, an optimized gradient boosting library, for credit default prediction.
# We'll learn how to declare NVIDIA `A100` for our training function and `ImageSpec`
# for specifying our python dependencies.
# {{run-on-union}}
# ## Declaring workflow dependencies
#
# First, we start by importing all the dependencies that are required by this workflow:
import os
import gc
from pathlib import Path
from typing import Tuple
import fsspec
from flytekit import task, workflow, current_context, Resources, ImageSpec, Deck
from flytekit.types.file import FlyteFile
from flytekit.extras.accelerators import A100
```
Note that the text content is embedded in comments as Markdown, and the code is normal python code.
The generator will convert the markdown into normal page text content and the code into code blocks within that Markdown content.
### Run on Union Instructions
You can add the "Run on Union" instructions anywhere in the content.
Annotate the location where you want to include them with `{{run-on-union}}`, like this:
```markdown
# The quick brown fox wants to see the Union instructions.
#
# {{run-on-union}}
#
# And it shall have it.
```
The resulting **Run on Union** section in the rendered docs will include the run command and source location,
specified as `run_command` and `source_location` in the front matter of the corresponding `.md` page.
## Jupyter Notebooks
You can also generate pages from Jupyter notebooks.
At the top of your `.md` file, add:
```markdown
---
jupyter_notebook: /path/to/your/notebook.ipynb
---
```
Then run the `Makefile.jupyter` target to generate the page.
```shell
$ make -f Makefile.jupyter
```
> [!NOTE]
> You must `uv sync` and activate the environment in `tools/jupyter_generator` before running the
> `Makefile.jupyter` target, or make sure all the necessary dependencies are installed for yourself.
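For reference, one way to prepare that environment might look like this (assuming `uv` manages a `.venv` inside `tools/jupyter_generator`):
```shell
$ cd tools/jupyter_generator
$ uv sync
$ source .venv/bin/activate
$ cd ../..
$ make -f Makefile.jupyter
```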
**Committing the change:** When the PR is pushed, a check for consistency between the notebook and its source will run. Please ensure that if you change the notebook, you re-run the `Makefile.jupyter` target to update the page.
## Mapped Keys (`{{* key */>}}`)
`key` is a very special shortcode that allows us to define values mapped per variant.
For example, the product name changes depending on whether the site is Flyte, Union BYOC, etc. For that,
we can define a single key, `product_full_name`, and map it so the right name is rendered automatically,
without the need for an `if variant` block around it.
Please refer to **Contributing docs and examples > Authoring > {{* key */>}} shortcode** for more details.
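As a quick illustration, a single line like the following renders the right product name in every variant (using the keys listed under **Shortcodes**):
```markdown
Welcome to the {{* key product_full_name */>}} documentation.
```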
## Mermaid Graphs
To embed Mermaid diagrams in a page, insert the code inside a block like this:
```mermaid
your mermaid graph goes here
```
Also add `mermaid: true` to the top of your page to enable rendering.
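For example, this minimal (illustrative) diagram renders as a small flowchart:
```mermaid
flowchart LR
    A[Write docs] --> B[Preview locally]
    B --> C[Open a PR]
```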
> [!NOTE]
> You can use [Mermaid's playground](https://www.mermaidchart.com/play) to design diagrams and get the code
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/shortcodes ===
# Shortcodes
This site has special blocks ("shortcodes") that can be used to generate special content and components for the docs.
> [!NOTE]
> You can see examples by running the dev server and visiting
> [`http://localhost:1313/__docs_builder__/shortcodes/`](http://localhost:1313/__docs_builder__/shortcodes/).
> Note that this page is only visible locally. It does not appear in the menus or in the production build.
>
> If you need instructions on how to create the local environment and get the
> `localhost:1313` server running, please refer to the **Contributing docs and examples > Shortcodes > local development guide**.
## How to specify a "shortcode"
The shortcode is a string that is used to generate the HTML that is displayed.
You can specify parameters or include content inside it, where applicable.
> [!NOTE]
> If you specify content, you have to have a close tag.
Examples:
* A shortcode that just outputs something
```markdown
{{* key product_name */>}}
```
* A shortcode that has content inside
```markdown
{{* markdown */>}}
* Your markdown
* goes here
{{* /markdown */>}}
```
* A shortcode with parameters
```markdown
{{* link-card target="union-sdk" icon="workflow" title="Union SDK" */>}}
The Union SDK provides the Python API for building Union workflows and apps.
{{* /link-card */>}}
```
> [!NOTE]
> If you're wondering why we have a `{{* markdown */>}}` when we can generate markdown at the top level, it is due to a quirk in Hugo:
> * At the top level of the page, Hugo can render markdown directly, interspersed with shortcodes.
> * However, *inside* a container shortcode, Hugo can only render *either* other shortcodes *or* Markdown.
> * The `{{* markdown */>}}` shortcode is designed to contain only Markdown (not other shortcodes).
> * All other container shortcodes are designed to contain only other shortcodes.
## Variants
The big difference between this site and other documentation sites is that we generate multiple "flavors" of the documentation that are slightly different from each other. We call these "variants."
When you are writing your content, and you want a specific part of the content to be conditional to a flavor, say "BYOC", you surround that with `variant`.
>[!NOTE]
> `variant` is a container, so inside you will specify what you are wrapping.
> You can wrap any of the shortcodes listed in this document.
Example:
```markdown
{{* variant serverless byoc */>}}
{{* markdown */>}}
**The quick brown fox signed up for Union!**
{{* /markdown */>}}
{{* button-link text="Contact Us" target="https://union.ai/contact" */>}}
{{* /variant */>}}
```
## Component Library
### `{{* audio */>}}`
Generates an audio media player.
### `{{* grid */>}}`
Creates a fixed column grid for lining up content.
### `{{* variant */>}}`
Filters content based on which flavor you're seeing.
### `{{* link-card */>}}`
A floating, clickable, navigable card.
### `{{* markdown */>}}`
Generates a markdown block, to be used inside containers such as `{{* dropdown */>}}` or `{{* variant */>}}`.
### `{{* multiline */>}}`
Generates a multi-line, single paragraph. Useful for making a multiline table cell.
### `{{* tabs */>}}` and `{{* tab */>}}`
Generates a tab panel with content switching per tab.
### `{{* key */>}}`
Outputs one of the pre-defined keywords.
Enables inline text that differs per variant without using the heavyweight `{{* variant */>}}...{{* /variant */>}}` construct.
Take, for example, the following:
```markdown
The {{* key product_name */>}} platform is awesome.
```
In the Flyte variant of the site this will render as:
> The Flyte platform is awesome.
While, in the BYOC, Self-managed and Serverless variants of the site it will render as:
> The Union.ai platform is awesome.
You can add keywords and specify their value, per variant, in `hugo.toml`:
```toml
[params.key.product_full_name]
flyte = "Flyte"
serverless = "Union Serverless"
byoc = "Union BYOC"
selfmanaged = "Union Self-managed"
```
#### List of available keys
| Key | Description | Example Usage (Flyte → Union) |
| ----------------- | ------------------------------------- | ------------------------------------------------------------------------ |
| default_project | Default project name used in examples | `{{* key default_project */>}}` → "flytesnacks" or "default" |
| product_full_name | Full product name | `{{* key product_full_name */>}}` → "Flyte OSS" or "Union.ai Serverless" |
| product_name | Short product name | `{{* key product_name */>}}` → "Flyte" or "Union.ai" |
| product | Lowercase product identifier | `{{* key product */>}}` → "flyte" or "union" |
| kit_name | SDK name | `{{* key kit_name */>}}` → "Flytekit" or "Union" |
| kit | Lowercase SDK identifier | `{{* key kit */>}}` → "flytekit" or "union" |
| kit_as | SDK import alias | `{{* key kit_as */>}}` → "fl" or "union" |
| kit_import | SDK import statement | `{{* key kit_import */>}}` → "flytekit as fl" or "union" |
| kit_remote | Remote client class name | `{{* key kit_remote */>}}` → "FlyteRemote" or "UnionRemote" |
| cli_name | CLI tool name | `{{* key cli_name */>}}` → "Pyflyte" or "Union" |
| cli | Lowercase CLI tool identifier | `{{* key cli */>}}` → "pyflyte" or "union" |
| ctl_name | Control tool name | `{{* key ctl_name */>}}` → "Flytectl" or "Uctl" |
| ctl | Lowercase control tool identifier | `{{* key ctl */>}}` → "flytectl" or "uctl" |
| config_env | Configuration environment variable | `{{* key config_env */>}}` → "FLYTECTL_CONFIG" or "UNION_CONFIG" |
| env_prefix | Environment variable prefix | `{{* key env_prefix */>}}` → "FLYTE" or "UNION" |
| docs_home | Documentation home URL | `{{* key docs_home */>}}` → "/docs/flyte" or "/docs/serverless" |
| map_func | Map function name | `{{* key map_func */>}}` → "map_task" or "map" |
| logo | Logo image filename | `{{* key logo */>}}` → "flyte-logo.svg" or "union-logo.svg" |
| favicon | Favicon image filename | `{{* key favicon */>}}` → "flyte-favicon.ico" or "union-favicon.ico" |
### `{{* download */>}}`
Generates a download link.
Parameters:
- `url`: The URL to download from
- `filename`: The filename to save the file as
- `text`: The text to display for the download link
Example:
```markdown
{{* download "/_static/public/public-key.txt" "public-key.txt" */>}}
```
### `{{* docs_home */>}}`
Produces a link to the home page of the documentation for a specific variant.
Example:
```markdown
[See this in Flyte]({{* docs_home flyte */>}}/wherever/you/want/to/go/in/flyte/docs)
```
### `{{* py_class_docsum */>}}`, `{{* py_class_ref */>}}`, and `{{* py_func_ref */>}}`
Helper functions to track Python classes in Flyte documentation, so we can link them to
the appropriate documentation.
Parameters:
- name of the class
- text to add to the link
Example:
```markdown
Please see {{* py_class_ref flyte.core.Image */>}} for more details.
```
### `{{* icon name */>}}`
Uses a named icon in the content.
Example:
```markdown
[Download {{* icon download */>}}](/download)
```
### `{{* code */>}}`
Includes a code snippet or file.
Parameters:
- `file`: The path to the file to include.
- `fragment`: The name of the fragment to include.
- `from`: The line number to start including from.
- `to`: The line number to stop including at.
- `lang`: The language of the code snippet.
- `show_fragments`: Whether to show the fragment names in the code block.
- `highlight`: Whether to highlight the code snippet.
The examples in this section use this file as a base:
```
def main():
    """
    A sample function
    """
    return 42
# {{docs-fragment entrypoint}}
if __name__ == "__main__":
    main()
# {{/docs-fragment}}
```
*Source: /_static/__docs_builder__/sample.py*
Link to [/_static/__docs_builder__/sample.py](/_static/__docs_builder__/sample.py)
#### Including a section of a file: `{{docs-fragment}}`
```markdown
{{* code file="/_static/__docs_builder__/sample.py" fragment=entrypoint lang=python */>}}
```
Effect:
```
if __name__ == "__main__":
    main()
```
*Source: /_static/__docs_builder__/sample.py*
#### Including a file with a specific line range: `from` and `to`
```markdown
{{* code file="/_static/__docs_builder__/sample.py" from=2 to=4 lang=python */>}}
```
Effect:
```
    """
    A sample function
    """
```
*Source: /_static/__docs_builder__/sample.py*
#### Including a whole file
Simply specify no filters, just the `file` attribute:
```markdown
{{* code file="/_static/__docs_builder__/sample.py" */>}}
```
> [!NOTE]
> Note that without `show_fragments=true` the fragment markers will not be shown.
Effect:
```
def main():
    """
    A sample function
    """
    return 42
if __name__ == "__main__":
    main()
```
*Source: /_static/__docs_builder__/sample.py*
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/redirects ===
# Redirects
We use Cloudflare's Bulk Redirect to map URLs that moved to their new location,
so the user does not get a 404 using the old link.
The redirect files are in CSV format, with the following structure:
`<source URL>,<target URL>,302,TRUE,FALSE,TRUE,TRUE`
- `<source URL>`: the URL without `https://`
- `<target URL>`: the full URL (including `https://`) to send the user to
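For example, a single entry might look like this (URLs are illustrative):
```
docs.union.ai/some/old/path,/docs/v1/byoc/some/new/path,302,TRUE,FALSE,TRUE,TRUE
```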
Redirects are recorded in `redirects.csv` file in the root of the repository.
To take effect, this file must be applied to the production environment on CloudFlare by a Union employee.
If you need to add a new redirect, please create a pull request with the change to `redirects.csv` and a note indicating that you would like to have it applied to production.
## `docs.union.ai` redirects
For redirects from the old `docs.union.ai` site to the new `www.union.ai/docs` site, we use the original request URL. For example:
| Field | Value |
|-|-|
| Request URL | `https://docs.union.ai/administration` |
| Target URL | `/docs/v1/byoc//user-guide/administration` |
| Redirect Entry | `docs.union.ai/administration,/docs/v1/byoc//user-guide/administration,302,TRUE,FALSE,TRUE,TRUE` |
## `docs.flyte.org` redirects
For redirects from the old `docs.flyte.org` to the new `www.union.ai/docs`, we replace the `docs.flyte.org` in the request URL with the special prefix `www.union.ai/_r_/flyte`. For example:
| Field | Value |
|-|-|
| Request URL | `https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.dynamic.html` |
| Converted request URL | `www.union.ai/_r_/flyte/projects/flytekit/en/latest/generated/flytekit.dynamic.html` |
| Target URL | `/docs/v1/flyte//api-reference/flytekit-sdk/packages/flytekit.core.dynamic_workflow_task/` |
| Redirect Entry | `www.union.ai/_r_/flyte/projects/flytekit/en/latest/generated/flytekit.dynamic.html,/docs/v1/flyte//api-reference/flytekit-sdk/packages/flytekit.core.dynamic_workflow_task/,302,TRUE,FALSE,TRUE,TRUE` |
The special prefix is used so that we can include both `docs.union.ai` and `docs.flyte.org` redirects in the same file and apply them on the same domain (`www.union.ai`).
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/api-docs ===
# API docs
You can import Python APIs and host them on the site. To do that you will use
the `tools/api_generator` to parse and create the appropriate markdown.
Please refer to [`api_generator/README`](https://github.com/unionai/docs/blob/main/tools/api_generator/README.md) for more details.
## API naming convention
All the buildable APIs are at the root in the form:
`Makefile.api.<name>`
To build one, run `make -f Makefile.api.<name>` and observe the setup
requirements in the `README.md` file above.
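For example, assuming an API makefile named `Makefile.api.flytekit` exists at the root (the name here is illustrative):
```shell
$ make -f Makefile.api.flytekit
```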
## Package Resource Resolution
When scanning the packages we need to know when to include or exclude an object
(class, function, variable) from the documentation. The parser will follow this
workflow to decide, in order, if the resource must be in or out:
1. If an `__all__: List[str]` package-level variable is present, only the resources
listed in it will be exposed. All other resources are excluded.
Example:
```python
from http import HTTPStatus, HTTPMethod
__all__ = ["HTTPStatus", "LocalThingy"]
class LocalThingy:
    ...
class AnotherLocalThingy:
    ...
```
In this example only `HTTPStatus` and `LocalThingy` will show in the docs.
Both `HTTPMethod` and `AnotherLocalThingy` are ignored.
2. If `__all__` is not present, these rules are observed:
- All imported packages are ignored
- All objects starting with `_` are ignored
Example:
```python
from http import HTTPStatus, HTTPMethod
class _LocalThingy:
    ...
class AnotherLocalThingy:
    ...
def _a_func():
    ...
def b_func():
    ...
```
In this example only `AnotherLocalThingy` and `b_func` will show in the docs.
Neither the imports nor `_LocalThingy` will show in the documentation.
## Tips and Tricks
1. If a package has no public resources (everything is imported or starts with `_`) and no
`__all__` to export those otherwise-blocked resources, the package will have no content and thus will not be generated.
2. If you want to export something you `from ___ import ____`, you _must_
use `__all__` to add the private import to the public list.
3. If all your methods follow the Python convention (everything private starts
with `_` and everything you want public does not), you do not need an
`__all__` allow list.
=== PAGE: https://www.union.ai/docs/v2/flyte/community/contributing-docs/publishing ===
# Publishing
## Requirements
1. Hugo (https://gohugo.io/)
```shell
$ brew install hugo
```
2. A preferences override file with your configuration
The tool is flexible and has multiple knobs. Please review `hugo.local.toml~sample` and configure it to meet your preferences.
```shell
$ cp hugo.local.toml~sample hugo.local.toml
```
3. Make sure you review `hugo.local.toml`.
## Managing the Tutorial Pages
The tutorials are maintained in the [unionai/unionai-examples](https://github.com/unionai/unionai-examples) repository and imported as a git submodule in the `external`
directory.
To initialize the submodule on a fresh clone of this (`docs-builder`) repo, run:
```
$ make init-examples
```
To update the submodule to the latest `main` branch, run:
```
$ make update-examples
```
## Building and running locally
```
$ make dev
```
## Developer Experience
Running `make dev` launches the site in development mode.
Changes are hot-reloaded: edit in your favorite editor and the page will refresh immediately in the browser.
### Controlling Development Environment
You can change how the development environment works by setting values in `hugo.local.toml`. The following settings are available:
* `variant` - The current variant to display. Change this in `hugo.local.toml`, save, and the browser will refresh automatically
with the new variant.
* `show_inactive` - If 'true', it will show all the content that did not match the variant.
This is useful when the page contains multiple sections that vary with the selected variant,
so you can see all at once.
* `highlight_active` - If 'true', it will also highlight the *current* content for the variant.
* `highlight_keys` - If 'true', it highlights replacement keys and their values.
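Putting these together, a `hugo.local.toml` might look like this (values are illustrative):
```toml
variant = "flyte"
show_inactive = false
highlight_active = true
highlight_keys = false
```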
### Changing 'variants'
Variants are flavors of the site (that you can change at the top).
During development, you can render any variant by setting it in `hugo.local.toml`:
```
variant = "byoc"
```
We call this the "active" variant.
You can also render variant content from other variants at the same time as well as highlighting the content of your active variant:
To show the content from variants other than the currently active one set:
```
show_inactive = true
```
To highlight the content of the currently active variant (to distinguish it from common content that applies to all variants), set:
```
highlight_active = true
```
> You can create your own copy of `hugo.local.toml` by copying from `hugo.local.toml~sample` to get started.
## Troubleshooting
### Identifying Problems: Missing Content
Content may be hidden due to `{{* variant */>}}` blocks. To see what's missing,
you can adjust the variant show/hide in development mode.
For a production-like look, set:
```
show_inactive = false
highlight_active = false
```
For the full developer experience, set:
```
show_inactive = true
highlight_active = true
```
### Identifying Problems: Page Visibility
The developer site will show you in red any pages missing from the variant.
For a page to exist in a variant (or be explicitly excluded; you must pick one), it must be listed in the `variants:` field at the top of the file.
Clicking on the red page will give you the path you must add to the appropriate variant in the YAML file and a link with guidance.
Please refer to **Contributing docs and examples > Authoring** for more details.
## Building Production
```
$ make dist
```
This will build all the variants and place the result in the `dist` folder.
### Testing Production Build
You can run a local web server and serve the `dist/` folder. The site should behave exactly as it would at its official URL.
To start a server:
```
$ make serve [PORT=<port>]
```
If `PORT` is not specified, it defaults to `9000`.
Example:
```
$ make serve PORT=4444
```
Then open the browser at `http://localhost:<port>` to see the content. In the example above, it would be `http://localhost:4444/`.
=== PAGE: https://www.union.ai/docs/v2/flyte/release-notes ===
# Release notes
=== PAGE: https://www.union.ai/docs/v2/flyte/deployment ===
# Platform deployment
Flyte is distributed as a Helm chart with different supported deployment scenarios.
Union.ai is the platform built on top of Flyte that extends its capabilities to include RBAC, instant containers, real-time serving, and more.
The following diagram describes the available deployment paths for both options:
```mermaid
flowchart TD
A("Deployment paths") --> n1["Testing/evaluating"] & n4["Production deployment"]
    n1 -- Union.ai in the browser --> n2["Union Serverless"]
    n1 -- Compact Flyte cluster in a local container --> n3["flytectl demo start"]
    n4 --> n5["Run Flyte"] & n8["Run Union.ai"]
    n5 -- small scale --> n6["flyte-binary Helm chart"]
    n5 -- large scale or multi-cluster --> n7["flyte-core Helm chart"]
    n8 -- "You manage your data plane. Union.ai manages the control plane" --> n9["Self-managed"]
    n8 -- Union.ai manages control and data planes --> n10["BYOC"]
n1@{ shape: diam}
n4@{ shape: rounded}
n2@{ shape: rounded}
n3@{ shape: rounded}
n5@{ shape: diam}
n8@{ shape: diam}
n6@{ shape: rounded}
n7@{ shape: rounded}
n9@{ shape: rounded}
n10@{ shape: rounded}
```
This section walks you through the process of creating a Flyte cluster and covers topics related to enabling and configuring plugins, authentication, performance tuning, and maintaining Flyte as a production-grade service.
## Subpages
- **Flyte deployment**
- **Platform configuration**
- **Connector setup**
- **Plugins**
- **Configuration reference**
=== PAGE: https://www.union.ai/docs/v2/flyte/deployment/flyte-deployment ===
# Flyte deployment
This section covers Flyte deployment.
## Subpages
- **Flyte deployment > Components of a Flyte deployment**
- **Flyte deployment > Installing Flyte**
- **Flyte deployment > Multi-cluster**
=== PAGE: https://www.union.ai/docs/v2/flyte/deployment/flyte-deployment/planning ===
# Components of a Flyte deployment
A Flyte cluster is composed of 3 logical planes, as described in the following table:
| Plane | Description | Component |
|---|---|---|
| User plane | Tools to interact with the API | `flytekit`, `flytectl`, and `pyflyte` |
| Control plane | Processes incoming requests, implements core logic, maintains metadata and resource inventory. | `flyteadmin`, `datacatalog`, and `flytescheduler`. |
| Data plane | It fulfills execution requests, including instantiating plugins/connectors. | `flytepropeller`, `clusterresourcessync` |
# External dependencies
Regardless of the deployment path you choose, Flyte relies on a few elements to operate.
## Kubernetes cluster
It's recommended to use a [supported Kubernetes version](https://kubernetes.io/releases/version-skew-policy/#supported-versions). Flyte doesn't impose a requirement on the provider or method you use to stand up the K8s cluster: it can be anything from `k3s` on edge devices to massive K8s environments in the cloud or on-prem bare metal.
## Relational Database
Both `flyteadmin` and `datacatalog` rely on a PostgreSQL 12+ instance to store persistent records.
## Object store
Core Flyte components such as `flyteadmin`, `flytepropeller`, `datacatalog`, and user runtime containers (spawned for each execution) rely on an object store to hold files.
A Flyte deployment requires at least one storage bucket from an S3-compliant provider with the following minimum permissions:
- DeleteObject
- GetObject
- ListBucket
- PutObject
## Optional dependencies
Flyte can be operated without the following elements, but is prepared to use them if available for better integration with your current infrastructure:
### Ingress controller
Flyte operates with two protocols: `HTTP` for the UI and `gRPC` for the client-to-control-plane communication. You can expose both ports through `port-forward` which is typically a temporary measure, or expose them in a stable manner using Ingress. For a Kubernetes Ingress resource to be properly materialized, it needs an Ingress controller already installed in the cluster.
The Flyte Helm charts can trigger the creation of the Ingress resource but the config needs to be reconciled by an Ingress controller (doesn't ship with Flyte).
The Flyte community has used the following controllers successfully:
| Environment | Controller | Example configuration |
|---|---|---|
| AWS | ALB | [flyte-binary config](https://github.com/flyteorg/flyte/blob/754ab74b29f5fee665fd1cfde38fccccd95af8bd/charts/flyte-binary/eks-starter.yaml#L108-L120) / [flyte-core config](https://github.com/flyteorg/flyte/blob/754ab74b29f5fee665fd1cfde38fccccd95af8bd/charts/flyte-core/values-eks.yaml#L142-L160) |
| GCP | NGINX | [flyte-core example config](https://github.com/flyteorg/flyte/blob/754ab74b29f5fee665fd1cfde38fccccd95af8bd/charts/flyte-core/values-gcp.yaml#L160-L173) |
| Azure | NGINX | [flyte-core example config](https://github.com/flyteorg/flyte/blob/754ab74b29f5fee665fd1cfde38fccccd95af8bd/charts/flyte-core/values-gcp.yaml#L160-L173) |
| On-prem | NGINX, Traefik | |
### DNS
To register and run workflows in Flyte, your client (the CLI on your machine or an external system) needs to connect to the Flyte control plane through an endpoint. When you use `port-forward`, you typically access Flyte through `localhost`. For a production environment, it is recommended to use a valid DNS entry that points to your Ingress host name.
### SSL/TLS
Use a valid certificate to secure the communication between your client and the Flyte control plane. For Flyte, `insecure: true` means no certificate is installed. You can even use self-signed certificates (which count as `insecure: false`) by adding the `insecureSkipVerify: true` key to the local `config.yaml` file. This tells Flyte to skip verifying the certificate chain.
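For example, a local `config.yaml` for a deployment with a self-signed certificate might look like this (the endpoint is illustrative):
```yaml
admin:
  endpoint: dns:///flyte.example.com:443
  insecure: false           # a certificate is installed
  insecureSkipVerify: true  # but skip verifying the certificate chain (self-signed)
```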
## Helm chart variants
### Sandbox
It packages Flyte and all its dependencies into a single container that runs locally.
When you run `flytectl demo start` it creates the container using any OCI-compliant container engine you have available in your local system.
### flyte-binary
It packages all the Flyte components in a single Pod and is designed to scale up by adding more compute resources to the Deployment.
It doesn't include the dependencies, so you have to provision the storage bucket, Kubernetes cluster, and database before installing it.
The repo includes [example values files](https://github.com/flyteorg/flyte/tree/master/charts/flyte-binary) for different environments.
> The [Flyte the Hard Way](https://github.com/davidmirror-ops/flyte-the-hard-way) community-maintained guide walks you through the semi-automated process of preparing the dependencies to install `flyte-binary`.
### flyte-core
It runs each Flyte component as a highly-available Deployment. The main difference with the flyte-binary chart is that flyte-core supports scaling out each Flyte component independently.
## Additional resources
### Terraform reference implementations
Flyte maintains a [Terraform codebase](https://github.com/unionai-oss/deploy-flyte) you can use to automatically configure all the dependencies and install Flyte in AWS, GCP, or Azure.
### Support
Reach out to the [#flyte-deployment](https://flyte-org.slack.com/archives/C01P3B761A6) community channel if you have questions during the deployment process.
[Union.ai](https://www.union.ai/contact) also offers paid Install Assist and different tiers of support services.
=== PAGE: https://www.union.ai/docs/v2/flyte/deployment/flyte-deployment/installing ===
# Installing Flyte
First, add the Flyte chart repo to Helm:
```bash
helm repo add flyteorg https://flyteorg.github.io/flyte
```
Then download a values file and update it as needed:
```bash
curl -sL https://raw.githubusercontent.com/flyteorg/flyte/master/charts/flyte-binary/eks-starter.yaml > eks-starter.yaml
```
> Both the [flyte-binary](https://github.com/flyteorg/flyte/tree/master/charts/flyte-binary) and [flyte-core](https://github.com/flyteorg/flyte/tree/master/charts/flyte-core) charts include example YAML values files for different cloud environments.
You can provide your own values file overriding the base config. The minimum information required for each chart is detailed in the following table:
| Required config | `flyte-binary` key |`flyte-core` key | Notes |
|---|---|---|---|
| Database password | `configuration.database.password` | `userSettings.dbPassword` | Default Postgres username: `postgres` |
| Database server | `configuration.database.host` |`userSettings.dbHost` (GCP and Azure), `userSettings.rdsHost`(EKS) | Default DB name: `flyteadmin`|
| S3 storage bucket | `configuration.storage.metadataContainer` / `configuration.storage.userDataContainer` |`userSettings.bucketName` / `userSettings.rawDataBucketName` | You can use the same bucket for both|
Once you have adjusted your values file, install the chart:
Example:
```bash
helm install flyte-backend flyteorg/flyte-binary \
--dry-run --namespace flyte --values eks-starter.yaml
```
When ready to install, remove the `--dry-run` switch.
## Verify the Installation
The base values files provide only the simplest installation of Flyte. The core functionality and scalability of Flyte will be there, but Ingress, authentication, and DNS/SSL are not configured.
### Port Forward Flyte Service
To verify the installation, you can port-forward the Kubernetes services:
Example:
```bash
kubectl -n flyte port-forward service/flyte-binary-http 8088:8088
kubectl -n flyte port-forward service/flyte-binary-grpc 8089:8089
```
You should be able to navigate to `http://localhost:8088/console`.
The Flyte server operates on two different ports, one for `HTTP` traffic and the other for `gRPC`, which is why we port forward both.
### Connect to your Flyte instance
- Generate a new configuration file (in case you don't have one already) using `flytectl config init`.
This will produce a file like the following:
```yaml
admin:
  # For GRPC endpoints you might want to use dns:///flyte.myexample.com
  endpoint: dns:///localhost:8089 # the gRPC endpoint
  authType: Pkce
  insecure: true
logger:
  show-source: true
  level: 0
```
- Test your connection using:
```bash
flytectl get projects
```
From this point on you can start running workflows!
=== PAGE: https://www.union.ai/docs/v2/flyte/deployment/flyte-deployment/multicluster ===
# Multi-cluster
The multi-cluster deployment described in this section assumes that you have deployed the `flyte-core` helm chart, which runs the individual flyte components separately.
This is needed because in a multi-cluster setup, the execution engine (`flytepropeller`) is deployed to multiple k8s clusters; hence it wouldn't work with the `flyte-binary` helm chart, since it deploys all flyte services as one single binary.
> [!NOTE]
> Union.ai offers simplified support for multi-cluster and multi-cloud.
> [Learn more](/docs/v1/byoc//deployment/multi-cluster#multi-cluster-and-multi-cloud) or [book a demo](https://union.ai/demo).
## Scaling Beyond Kubernetes
As described in the [Architecture Overview](https://docs.flyte.org/en/latest/concepts/architecture.html), the Flyte control plane (`flyteadmin`) sends workflows off to the Data Plane (`flytepropeller`) for execution.
The data plane fulfills these workflows by launching pods in Kubernetes.
The case for multiple Kubernetes clusters may arise due to security constraints, cost-effectiveness or a need to scale out computing resources.
To address this, you can deploy Flyte's data plane to multiple Kubernetes clusters.
The control plane (`flyteadmin`) can be configured to submit workflows to these individual data planes.
Additionally, Flyte provides the mechanisms for administrators to retain control on the workflow placement logic while enabling users to reap the benefits using simple abstractions like `projects` and `domains`.
### Prerequisites
To make sure that your multi-cluster deployment is able to scale and process requests successfully, the following environment-specific requirements should be met:
1. An IAM Policy that defines the permissions needed for Flyte. A minimum set of permissions include:
```json
"Action": [
"s3:DeleteObject*",
"s3:GetObject*",
"s3:ListBucket",
"s3:PutObject*"
],
"Resource": [
"arn:aws:s3:::*",
"arn:aws:s3:::*/*"
],
```
2. Two IAM Roles configured: one for the control plane components, and another for the data plane where the worker Pods and `flytepropeller` run.
Use the recommended security strategy for the cloud provider you're running on.
For example, IRSA for EKS environments or Workload Identity Federation for GCP.
3. Mapping between the `default` Service Account in each `project-domain` namespace and the assumed role in your cloud environment (see the sketch after this list).
By default, every Pod created for a Task execution uses the `default` Service Account in its respective namespace.
In your cluster, you'll have as many namespaces as `project` and `domain` combinations you have.
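On EKS with IRSA, for example, this mapping typically comes down to annotating the `default` Service Account in each namespace with the role to assume (a sketch; the namespace and role ARN are illustrative):
```shell
$ kubectl annotate serviceaccount default \
    -n project1-development \
    eks.amazonaws.com/role-arn=arn:aws:iam::123456789012:role/flyte-worker-role
```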
### Data Plane Deployment
This guide assumes that you have two Kubernetes clusters and that you can access them all with `kubectl`.
Let's call these clusters `dataplane1` and `dataplane2`. In this section, you'll prepare the first cluster only.
1. Add the `flyteorg` Helm repo:
```shell
$ helm repo add flyteorg https://flyteorg.github.io/flyte
$ helm repo update
```
2. Get the `flyte-core` Helm chart:
```shell
$ helm fetch --untar --untardir . flyteorg/flyte-core
$ cd flyte-core
```
3. Open the `values-dataplane.yaml` file and add the following contents:
```yaml
configmap:
  admin:
    admin:
      endpoint: <your-flyte-endpoint>:443 # indicate the URL you're using to connect to Flyte
      insecure: false # enables secure communication over SSL. Requires a signed certificate
  catalog:
    catalog-cache:
      endpoint: <your-catalog-endpoint>
      insecure: false
```
This step is needed so the `flytepropeller` instance in the data plane cluster is able to send notifications back to the `flyteadmin` service in the control plane.
The `catalog` service runs in the control plane and is used when caching is enabled.
Note that `catalog` is not exposed via the ingress by default and does not have its own authentication mechanism.
The `catalog` service in the control plane cluster can, for instance, be made available to the `flytepropeller` services in the data plane clusters with an internal load balancer service.
See the [GKE documentation](https://cloud.google.com/kubernetes-engine/docs/how-to/internal-load-balancing#create) or
[AWS Load Balancer Controller](https://kubernetes-sigs.github.io/aws-load-balancer-controller/latest/guide/service/nlb) if the clusters use the same VPC network.
4. Install the Flyte data plane Helm chart. Use the same base `values` file you used to deploy the control plane:
**AWS**
```bash
$ helm install flyte-core-data flyteorg/flyte-core -n flyte \
--values values-eks.yaml --values values-dataplane.yaml \
--create-namespace
```
**GCP**
```bash
$ helm install flyte-core-data -n flyte flyteorg/flyte-core \
--values values-gcp.yaml \
--values values-dataplane.yaml \
--create-namespace flyte
```
## Control Plane configuration
For `flyteadmin` to access and create Kubernetes resources in one or more Flyte data plane clusters, it needs credentials to each cluster.
Flyte makes use of Kubernetes Service Accounts to enable every control plane cluster to perform authenticated requests to the Kubernetes API Server in the data plane cluster.
The default behavior is that the Helm chart creates a [ServiceAccount](https://github.com/flyteorg/flyte/blob/master/charts/flyte-core/templates/admin/rbac.yaml#L4) in each data plane cluster.
In order to verify requests, the Kubernetes API Server expects a [signed bearer token](https://kubernetes.io/docs/reference/access-authn-authz/authentication/#service-account-tokens) attached to the Service Account.
Starting with Kubernetes 1.24, the bearer token has to be generated manually.
1. Use the following manifest to create a long-lived bearer token for the `flyteadmin` Service Account in your data plane cluster (the Secret name `dataplane1-token` is reused in the following steps):
```shell
$ kubectl apply -f - <<EOF
apiVersion: v1
kind: Secret
metadata:
  name: dataplane1-token
  namespace: flyte
  annotations:
    kubernetes.io/service-account.name: flyteadmin
type: kubernetes.io/service-account-token
EOF
```
2. Create a file named `secrets.yaml` to hold the data plane credentials that the control plane will use.
> [!NOTE]
> The credentials have two parts (`CA cert` and `bearer token`).
3. Copy the bearer token of the first data plane cluster's secret to your clipboard using the following command:
```shell
$ kubectl get secret -n flyte dataplane1-token \
-o jsonpath='{.data.token}' | pbcopy
```
4. Go to `secrets.yaml` and add a new entry under `data` with the data plane cluster token:
```yaml
apiVersion: v1
kind: Secret
metadata:
  name: cluster-credentials
  namespace: flyte
type: Opaque
data:
  dataplane_1_token: <your-dataplane1-token>
```
5. Obtain the corresponding certificate:
```shell
$ kubectl get secret -n flyte dataplane1-token \
-o jsonpath='{.data.ca\.crt}' | pbcopy
```
6. Add another entry in your `secrets.yaml` file for the certificate:
```yaml
apiVersion: v1
kind: Secret
metadata:
  name: cluster-credentials
  namespace: flyte
type: Opaque
data:
  dataplane_1_token: <your-dataplane1-token>
  dataplane_1_cacert: <your-dataplane1-cacert>
```
7. Connect to your control plane cluster and create the `cluster-credentials` secret:
```shell
$ kubectl apply -f secrets.yaml
```
8. Create a file named `values-override.yaml` and add the following config to it:
```yaml
flyteadmin:
  additionalVolumes:
    - name: cluster-credentials
      secret:
        secretName: cluster-credentials
  additionalVolumeMounts:
    - name: cluster-credentials
      mountPath: /var/run/credentials
  initContainerClusterSyncAdditionalVolumeMounts:
    - name: cluster-credentials
      mountPath: /etc/credentials
configmap:
  clusters:
    labelClusterMap:
      label1:
        - id: dataplane_1
          weight: 1
    clusterConfigs:
      - name: "dataplane_1"
        endpoint: https://<your-dataplane1-kubeapi-endpoint>:443
        enabled: true
        auth:
          type: "file_path"
          tokenPath: "/var/run/credentials/dataplane_1_token"
          certPath: "/var/run/credentials/dataplane_1_cacert"
```
> [!NOTE]
> Typically, you can obtain your Kubernetes API endpoint URL using `kubectl cluster-info`
In this configuration, `label1` and `label2` are just labels that we will use later in the process to configure mappings that enable workflow executions matching those labels to be scheduled on one or multiple clusters depending on the weight (e.g. `label1` on `dataplane_1`). The `weight` is the priority of a specific cluster relative to the other clusters under the same `labelClusterMap` entry. The total sum of weights under a particular label has to be `1`.
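For example, to split executions for a single label evenly across two data planes, you could weight both clusters equally (an illustrative sketch):
```yaml
labelClusterMap:
  label1:
    - id: dataplane_1
      weight: 0.5
    - id: dataplane_2
      weight: 0.5
```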
9. Add the data plane IAM Role as the `defaultIamRole` in your Helm values file. [See AWS example](https://github.com/flyteorg/flyte/blob/97a79c030555eaefa3e27383d9b933ba1fdc1140/charts/flyte-core/values-eks.yaml#L351-L365)
10. Update the control plane Helm release:
This step will disable `flytepropeller` in the control plane cluster, leaving no possibility of running workflows there. If you require the control plane to run workflows, edit the `values-controlplane.yaml` file and set `flytepropeller.enabled` to `true` and add one additional cluster config for the control plane cluster itself:
```yaml
configmap:
  clusters:
    clusterConfigs:
      - name: "dataplane_1"
        ...
      - name: "controlplane"
        enabled: true
        inCluster: true # Use in-cluster credentials
```
Then, complete the `helm upgrade` operation.
**AWS**
```shell
$ helm upgrade flyte-core flyteorg/flyte-core \
--values values-eks-controlplane.yaml --values values-override.yaml \
--values values-eks.yaml -n flyte
```
**GCP**
```shell
$ helm upgrade flyte -n flyte flyteorg/flyte-core \
    --values values-gcp.yaml \
    --values values-controlplane.yaml \
    --values values-override.yaml
```
11. Verify that all Pods in the `flyte` namespace are `Running`:
```shell
$ kubectl get pods -n flyte
```
Example output:
```shell
NAME READY STATUS RESTARTS AGE
datacatalog-86f6b9bf64-bp2cj 1/1 Running 0 23h
datacatalog-86f6b9bf64-fjzcp 1/1 Running 0 23h
flyteadmin-84f666b6f5-7g65j 1/1 Running 0 23h
flyteadmin-84f666b6f5-sqfwv 1/1 Running 0 23h
flyteconsole-cdcb48b56-5qzlb 1/1 Running 0 23h
flyteconsole-cdcb48b56-zj75l 1/1 Running 0 23h
flytescheduler-947ccbd6-r8kg5 1/1 Running 0 23h
syncresources-6d8794bbcb-754wn 1/1 Running 0 23h
```
## Configure Execution Cluster Labels
The next step is to configure project-domain or workflow labels to schedule on a specific Kubernetes cluster.
### Project-domain execution labels
1. Create an `ecl.yaml` file with the following contents:
```yaml
domain: development
project: project1
value: label1
```
> [!NOTE]
> Change `domain` and `project` according to your environment. The `value` has to match with the entry under `labelClusterMap` in the `values-override.yaml` file.
2. Repeat step 1 for every project-domain mapping you need to configure, creating a YAML file for each one.
3. Update the execution cluster label of the project and domain:
```shell
$ flytectl update execution-cluster-label --attrFile ecl.yaml
```
Example output:
```shell
Updated attributes from team1 project and domain development
```
4. Execute a workflow indicating project and domain:
```shell
$ pyflyte run --remote --project team1 --domain development example.py training_workflow \
--hyperparameters '{"C": 0.1}'
```
### Configure a Specific Workflow mapping
1. Create a `workflow-ecl.yaml` file with the following example contents:
```yaml
domain: development
project: project1
workflow: example.training_workflow
value: project1
```
2. Update the execution cluster label of the workflow:
```shell
$ flytectl update execution-cluster-label \
-p project1 -d development \
example.training_workflow \
--attrFile workflow-ecl.yaml
```
3. Execute a workflow indicating project and domain:
```shell
$ pyflyte run --remote --project team1 --domain development example.py training_workflow \
--hyperparameters '{"C": 0.1}'
```
Congratulations!
With this, the execution of workflows belonging to a specific
project-domain or a single specific workflow will be scheduled on the target label
cluster.
## Day 2 Operations
### Add another Kubernetes cluster
The process can be repeated for additional clusters.
1. Provision the new cluster and add it to the permissions structure (IAM, etc.).
2. Install the data plane Helm chart following the steps in the **Flyte deployment > Multi-cluster > Scaling Beyond Kubernetes > Data Plane Deployment** section.
3. Follow steps 1-3 in the **Flyte deployment > Multi-cluster > Control Plane configuration** to generate and populate a new section in your `secrets.yaml` file.
For example:
```yaml
apiVersion: v1
kind: Secret
metadata:
name: cluster-credentials
namespace: flyte
type: Opaque
data:
dataplane_1_token:
dataplane_1_cacert:
dataplane_2_token:
dataplane_2_cacert:
```
4. Connect to the control plane cluster and update the `cluster-credentials` Secret:
```bash
kubectl apply -f secrets.yaml
```
5. Go to your `values-override.yaml` file and add the information of the new cluster.
Adding a new label is not strictly required.
Nevertheless, in the following example a new label is created to illustrate Flyte's capability to schedule workloads on different clusters in response to user-defined mappings of `project`, `domain`, and `label`:
```yaml
... # all the above content remains the same
configmap:
  clusters:
    labelClusterMap:
      label1:
        - id: dataplane_1
          weight: 1
      label2:
        - id: dataplane_2
          weight: 1
    clusterConfigs:
      - name: "dataplane_1"
        endpoint: https://<your-dataplane1-kubeapi-endpoint>:443
        enabled: true
        auth:
          type: "file_path"
          tokenPath: "/var/run/credentials/dataplane_1_token"
          certPath: "/var/run/credentials/dataplane_1_cacert"
      - name: "dataplane_2"
        endpoint: https://<your-dataplane2-kubeapi-endpoint>:443
        enabled: true
        auth:
          type: "file_path"
          tokenPath: "/var/run/credentials/dataplane_2_token"
          certPath: "/var/run/credentials/dataplane_2_cacert"
```
6. Update the Helm release in the control plane cluster:
```shell
$ helm upgrade flyte-core-control flyteorg/flyte-core -n flyte --values values-controlplane.yaml --values values-eks.yaml --values values-override.yaml
```
7. Create a new execution cluster labels file with the following sample content:
```yaml
domain: production
project: team1
value: label2
```
8. Update the cluster execution labels for the project:
```shell
$ flytectl update execution-cluster-label --attrFile ecl-production.yaml
```
9. Finally, submit a workflow execution that matches the label of the new cluster:
```shell
$ pyflyte run --remote --project team1 --domain production example.py \
training_workflow --hyperparameters '{"C": 0.1}'
```
10. A successful execution should be visible in the UI, confirming that it ran in the new cluster.
=== PAGE: https://www.union.ai/docs/v2/flyte/deployment/flyte-configuration ===
# Platform configuration
This section covers configuring Flyte for deeper integrations with existing infrastructure.
## Subpages
- **Platform configuration > Configuring authentication**
- **Platform configuration > Monitoring a Flyte deployment**
- **Platform configuration > Configuring logging links in the UI**
- **Platform configuration > Configuring Access to GPUs**
- **Platform configuration > Configuring task pods with K8s PodTemplates**
- **Platform configuration > Cloud Events**
- **Platform configuration > Customizing project, domain, and workflow resources with flytectl**
- **Platform configuration > Platform Events**
- **Platform configuration > Workflow notifications**
- **Platform configuration > Optimizing Performance**
- **Platform configuration > Flyte ResourceManager**
- **Platform configuration > Secrets**
- **Platform configuration > Security Overview**
- **Platform configuration > Flyte API Playground: Swagger**
=== PAGE: https://www.union.ai/docs/v2/flyte/deployment/flyte-configuration/configuring-authentication ===
# Configuring authentication
The Flyte platform consists of multiple components. Securing communication between each component is crucial to ensure
the integrity of the overall system.
Flyte supports most of the [OAuth2.0](https://tools.ietf.org/html/rfc6749) authorization grants and uses them to control access to workflow and task executions as the main protected resources.
Additionally, Flyte implements the [OIDC1.0](https://openid.net/specs/openid-connect-core-1_0.html) standard to attach user identity to the authorization flow. This feature requires integration with an external Identity Provider.
The following diagram illustrates how the elements of the OAuth2.0 protocol map to the Flyte components involved in the authentication process:
```mermaid
sequenceDiagram
    participant Client as Client (CLI/UI/system)
    participant Propeller as Resource Server + Owner (flytepropeller)
    participant AuthServer as Authorization Server (flyteadmin/external IdP)
    Client->>Propeller: Authorization request
    Propeller->>AuthServer: Request authorization grant
    AuthServer->>Propeller: Issue authorization grant
    Propeller->>Client: Authorization grant
    Client->>AuthServer: Authorization grant
    AuthServer->>Client: Access token
    Client->>Propeller: Access token
    Propeller->>Client: Protected resource
```
There are two main dependencies required for a complete auth flow in Flyte:
* **OIDC (Identity Layer) configuration** The OIDC protocol allows clients (such as Flyte) to confirm the identity of a user, based on authentication done by an Authorization Server.
To enable this, you first need to register Flyte as an app (client) with your chosen Identity Provider (IdP).
* **An authorization server** The authorization server's job is to issue access tokens to clients so that they can access the protected resources.
Flyte ships with two options for the authorization server:
* **Internal authorization server**: It's part of `flyteadmin` and is a suitable choice for quick start or testing purposes.
* **External (custom) authorization server**: This is a service provided by one of the supported IdPs and is the recommended option if your organization needs to retain control over scope definitions, token expiration policies and other advanced security controls.
> [!NOTE]
> Regardless of the type of authorization server to use, you will still need an IdP to provide identity through OIDC.
## Configuring the identity layer
### Prerequisites
* A public domain name (e.g. example.foobar.com)
* A DNS entry mapping the Fully Qualified Domain Name to the Ingress `host`.
> [!NOTE]
> Check out this [community-maintained guide](https://github.com/davidmirror-ops/flyte-the-hard-way/blob/main/docs/06-intro-to-ingress.md) for more information about setting up Flyte in production, including Ingress.
### Configuring your IdP for OIDC
In this section, you can find canonical examples of how to set up OIDC on some of the supported IdPs, enabling users to authenticate in the
browser.
> [!NOTE]
> Using the following configurations as a reference, the community has successfully configured auth with other IdPs, since Flyte implements open standards.
#### Google
1. Create an OAuth2 Client Credential following the [official documentation](https://developers.google.com/identity/protocols/oauth2/openid-connect) and take note of the `client_id` and `client_secret`
2. In the **Authorized redirect URIs** field, add `http://localhost:30081/callback` for **sandbox** deployments or `https://<your-flyte-domain>/callback` for other deployment methods.
#### Okta
1. If you don't already have an Okta account, [sign up for one](https://developer.okta.com/signup/).
2. Create an app integration, with `OIDC - OpenID Connect` as the sign-on method and `Web Application` as the app type.
3. Add sign-in redirect URIs: `http://localhost:30081/callback` for sandbox or `https://<your-flyte-domain>/callback` for other Flyte deployment types.
4. *Optional* - Add logout redirect URIs: `http://localhost:30081/logout` for sandbox, `https://<your-flyte-domain>/logout` for other Flyte deployment methods.
5. Take note of the Client ID and Client Secret.
#### Keycloak
1. Create a realm using the [admin console](https://wjw465150.gitbooks.io/keycloak-documentation/content/server_admin/topics/realms/create.html).
2. [Create an OIDC client with client secret](https://wjw465150.gitbooks.io/keycloak-documentation/content/server_admin/topics/clients/client-oidc.html) and note them down.
3. Add Login redirect URIs: `http://localhost:30081/callback` for sandbox or `https://<your-flyte-domain>/callback` for other Flyte deployment methods.
#### Microsoft Entra ID
1. In the Azure portal, open Microsoft Entra ID from the left-hand menu.
2. From the Overview section, navigate to **App registrations** > **+ New registration**.
* Under Supported account types, select the option based on your organization's needs.
3. Configure Redirect URIs
* In the Redirect URI section, choose **Web** from the **Platform** dropdown and enter the following URIs based on your environment:
* Sandbox: `http://localhost:30081/callback`
* Production: `https://<your-flyte-domain>/callback`
4. Obtain Tenant and Client Information
* After registration, go to the app's Overview page.
* Take note of the Application (client) ID and Directory (tenant) ID. You'll need these in your Flyte configuration.
5. Create a Client Secret
* From the Certificates & Secrets tab, click + New client secret.
* Add a Description and set an Expiration period (e.g., 6 months or 12 months).
* Click Add and copy the Value of the client secret; it will be used in the Helm values.
6. If the Flyte deployment will be dealing with user data, set API permissions:
* Navigate to **API Permissions > + Add a permission**, select **Microsoft Graph > Delegated permissions**, and add the following permissions:
* `email`
* `openid`
* `profile`
* `offline_access`
* `User.Read`
7. Expose an API (for Custom Scopes). In the Expose an API tab:
* Click + Add a scope, and set the Scope name (e.g., access_flyte).
* Provide a Consent description and enable Admin consent required and Save.
* Then, click + Add a client application and enter the Client ID of your Flyte application.
8. Configure Mobile/Desktop Flow (for flytectl):
* Go to the Authentication tab, and click + Add a platform.
* Select Mobile and desktop applications.
* Add the following URI: `http://localhost:53593/callback`
* Scroll down to Advanced settings and enable Allow public client flows.
For further reference, check out the official [Entra ID Docs](https://docs.microsoft.com/en-us/power-apps/maker/portals/configure/configure-openid-settings) on how to configure the IdP for OpenIDConnect.
> Make sure the app is registered without [additional claims](https://docs.microsoft.com/en-us/power-apps/maker/portals/configure/configure-openid-settings#configure-additional-claims).
> **The OpenIDConnect authentication will not work otherwise**.
> Please refer to [this GitHub Issue](https://github.com/coreos/go-oidc/issues/215) and [Entra ID Docs](https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-protocols-oidc#sample-response) for more information.
### Apply the OIDC configuration to the Flyte backend
Select the Helm chart you used to install Flyte:
#### flyte-binary
1. Generate a random password to be used internally by `flytepropeller`
2. Use the following command to hash the password:
```shell
$ pip install bcrypt && python -c 'import bcrypt; import base64; print(base64.b64encode(bcrypt.hashpw("<your-random-password>".encode("utf-8"), bcrypt.gensalt(6))))'
```
3. Go to your values file and locate the `auth` section and replace values accordingly:
```yaml
auth:
  enabled: true
  oidc:
    # baseUrl: https://accounts.google.com # Uncomment for Google
    # baseUrl: https://<keycloak-url>/auth/realms/<realm> # Uncomment for Keycloak and update with your installation host and realm name
    # baseUrl: https://login.microsoftonline.com/<tenant-id>/v2.0 # Uncomment for Azure AD
    # For Okta, use the Issuer URI from Okta's default auth server
    baseUrl: https://dev-<org-id>.okta.com/oauth2/default
    # Replace with the client ID and secret created for Flyte in your IdP
    clientId: <client-id>
    clientSecret: <client-secret>
  internal:
    clientSecret: '<your-random-password>'
    # Use the output of step #2 (only the content inside of '')
    clientSecretHash: <bcrypt-hash>
  authorizedUris:
    - https://<your-flyte-domain>
```
4. Save your changes
5. Upgrade your Helm release with the new values:
```shell
$ helm upgrade <RELEASE_NAME> flyteorg/flyte-binary -n <namespace> --values <your-values-file>.yaml
```
Where `<RELEASE_NAME>` is the name of your Helm release, typically `flyte-backend`. You can find it using `helm ls -n <namespace>`.
6. Verify that your Flyte deployment now requires a successful login to your IdP to access the UI (`https://<your-flyte-domain>/console`)
#### flyte-core
1. Generate a random password to be used internally by `flytepropeller`
2. Use the following command to hash the password:
```shell
$ pip install bcrypt && python -c 'import bcrypt; import base64; print(base64.b64encode(bcrypt.hashpw("<your-random-password>".encode("utf-8"), bcrypt.gensalt(6))))'
```
Take note of the output (only the contents inside `''`).
3. Go to your Helm values file and add the client_secret provided by your IdP to the configuration:
```yaml
flyteadmin:
  secrets:
    oidc_client_secret: <client-secret-from-your-IdP>
```
4. Verify that the `configmap` section includes the following, replacing the content where indicated:
```yaml
configmap:
  adminServer:
    server:
      httpPort: 8088
      grpc:
        port: 8089
      security:
        secure: false
        useAuth: true
        allowCors: true
        allowedOrigins:
          # Accepting all domains for Sandbox installation
          - "*"
        allowedHeaders:
          - "Content-Type"
    auth:
      appAuth:
        thirdPartyConfig:
          flyteClient:
            clientId: flytectl
            redirectUri: http://localhost:53593/callback
            scopes:
              - offline
              - all
        selfAuthServer:
          staticClients:
            flyte-cli:
              id: flyte-cli
              redirect_uris:
                - http://localhost:53593/callback
                - http://localhost:12345/callback
              grant_types:
                - refresh_token
                - authorization_code
              response_types:
                - code
                - token
              scopes:
                - all
                - offline
                - access_token
              public: true
            flytectl:
              id: flytectl
              redirect_uris:
                - http://localhost:53593/callback
                - http://localhost:12345/callback
              grant_types:
                - refresh_token
                - authorization_code
              response_types:
                - code
                - token
              scopes:
                - all
                - offline
                - access_token
              public: true
            flytepropeller:
              id: flytepropeller
              # Use the bcrypt hash generated for your random password
              client_secret: "<bcrypt-hash>"
              redirect_uris:
                - http://localhost:3846/callback
              grant_types:
                - refresh_token
                - client_credentials
              response_types:
                - token
              scopes:
                - all
                - offline
                - access_token
              public: false
      authorizedUris:
        # Use the public URL of flyteadmin (a DNS record pointing to your Ingress resource)
        - https://<your-flyte-domain>
        - http://flyteadmin:80
        - http://flyteadmin.flyte.svc.cluster.local:80
      userAuth:
        openId:
          # baseUrl: https://accounts.google.com # Uncomment for Google
          # baseUrl: https://login.microsoftonline.com/<tenant-id>/v2.0 # Uncomment for Azure AD
          # For Okta, use the Issuer URI of the default auth server
          baseUrl: https://dev-<org-id>.okta.com/oauth2/default
          # Use the client ID generated by your IdP
          clientId: <client-id>
          scopes:
            - profile
            - openid
```
5. Additionally, at the root of the values file, add the following block and replace the necessary information:
```yaml
secrets:
  adminOauthClientCredentials:
    # If enabled is true, and `clientSecret` is specified, Helm will create and mount `flyte-secret-auth`.
    # If enabled is true, and `clientSecret` is null, it's up to the user to create `flyte-secret-auth` as described in
    # https://docs.flyte.org/en/latest/deployment/cluster_config/auth_setup.html#oauth2-authorization-server
    # and Helm will mount `flyte-secret-auth`.
    # If enabled is false, auth is not turned on.
    # Note: Unsupported combination: enabled.false and clientSecret.someValue
    enabled: true
    # Use the non-encoded version of the random password
    clientSecret: "<your-random-password>"
    clientId: flytepropeller
```
> For **multi-cluster and multi-cloud** deployments, you must add this Secret definition block to the `values-dataplane.yaml` file. If you are not running `flytepropeller` in the control plane cluster, you do not need to create this Secret there.
6. Save and exit your editor.
7. Upgrade your Helm release with the new configuration:
```shell
$ helm upgrade <RELEASE_NAME> flyteorg/flyte-core -n <YOUR_NAMESPACE> --values <YOUR_VALUES_FILE>.yaml
```
8. Verify that the `flytepropeller`, `flytescheduler` and `flyteadmin` Pods have restarted and are running:
```bash
kubectl get pods -n flyte
```
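If you let Helm manage the client credentials (see step 5 above), you can also confirm that the `flyte-secret-auth` Secret was created:
```bash
kubectl get secret flyte-secret-auth -n flyte
```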
**Congratulations!**
It should now be possible to go to the Flyte UI and be prompted for authentication with the default `PKCE` auth flow. `flytectl` should automatically pick up the change and start prompting for authentication as well.
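For reference, a minimal client-side `~/.flyte/config.yaml` for the default PKCE flow looks roughly like this (a sketch; replace the endpoint with your own deployment URL):
```yaml
admin:
  # FlyteAdmin gRPC endpoint (your public deployment URL)
  endpoint: dns:///<your-flyte-deployment-URL>
  authType: Pkce
  insecure: false
```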
The following sections guide you through configuring an external auth server (optional for most authorization flows) and describe the client-side configuration for all the auth flows supported by Flyte.
## Configuring your IdP as an External Authorization Server
In this section, you will find instructions on how to set up an OAuth2 Authorization Server in the different IdPs supported by Flyte:
### Okta
Okta's custom authorization servers are available through an add-on license, but free developer accounts do include access, which you can use to test the configuration before rolling it out more broadly.
1. From the left-hand menu, go to **Security** > **API**
2. Click on **Add Authorization Server**.
3. Assign an informative name and set the audience to the public URL of FlyteAdmin (e.g. https://example.foobar.com). The audience must exactly match one of the URIs in the `authorizedUris` section above.
4. Note down the **Issuer URI**; this will be used for all the `baseUrl` settings in the Flyte config.
5. Go to **Scopes** and click **Add Scope**.
6. Set the name to `all` (required) and check `Required` under the **User consent** option.
7. Uncheck the **Block services from requesting this scope** option and save your changes.
8. Add another scope, named `offline`. Check both the **Required** and **Include in public metadata** options.
9. Uncheck the **Block services from requesting this scope** option.
10. Click **Save**.
11. Go to **Access Policies**, click **Add New Access Policy**. Enter a name and description and enable **Assign to** - `All clients`.
12. Add a rule to the policy with the default settings (you can fine-tune these later).
13. Navigate back to the **Applications** section.
14. Create an integration for `flytectl`; it should be created with the **OIDC - OpenID Connect** sign-on method, and the **Native Application** type.
15. Add `http://localhost:53593/callback` to the sign-in redirect URIs. The other options can remain as default.
16. Assign this integration to any Okta users or groups who should be able to use the `flytectl` tool.
17. Note down the **Client ID**; there will not be a secret.
18. Create an integration for `flytepropeller`; it should be created with the **OIDC - OpenID Connect** sign-on method and **Web Application** type.
19. Check the `Client Credentials` option under **Client acting on behalf of itself**.
20. This app does not need a specific redirect URI, nor does it need to be assigned to any users.
21. Note down the **Client ID** and **Client secret**; you will need these later.
22. Take note of the **Issuer URI** for your Authorization Server; it will be used as the `baseUrl` parameter in the Helm chart.
You should have three integrations total - one for the web interface (`flyteconsole`), one for `flytectl`, and one for `flytepropeller`.
### Keycloak
1. Create a realm in your Keycloak installation using its [admin console](https://wjw465150.gitbooks.io/keycloak-documentation/content/server_admin/topics/realms/create.html).
2. Under `Client Scopes`, click `Create` inside the admin console.
3. Create two clients (for `flytectl` and `flytepropeller`) to enable these clients to communicate with the service.
4. `flytectl` should be created with `Access Type Public` and the standard flow enabled.
5. `flytepropeller` should be created with `Access Type Confidential` and the standard flow enabled.
6. Take note of the client IDs and client secrets provided.
### Microsoft Entra ID
1. Navigate to tab **Overview** and obtain `<client-id>` and `<tenant-id>`.
2. Navigate to tab **Authentication** and click `+Add a platform`.
3. Add **Web** for `flyteconsole` and `flytepropeller`, and **Mobile and desktop applications** for `flytectl`.
4. Add the URL `https://<your-flyte-deployment-URL>/callback` as the callback for Web.
5. Add the URL `http://localhost:53593/callback` as the callback for `flytectl`.
6. In **Advanced settings**, set `Enable the following mobile and desktop flows` to **Yes** to enable the device flow.
7. Navigate to tab **Certificates & secrets** and click `+New client secret` to create the `<client-secret>`.
8. Navigate to tab **Token configuration**, click `+Add optional claim`, and create email claims for both the ID and access tokens.
9. Navigate to tab **API permissions** and add `email`, `offline_access`, `openid`, `profile`, and `User.Read`.
10. Navigate to tab **Expose an API**, then click `+Add a scope` and `+Add a client application` to create the `<custom-scope>`.
### Apply the external auth server configuration to Flyte
Follow the steps in this section to configure `flyteadmin` to use an external auth server. This section assumes that you have already completed and applied the configuration for the OIDC Identity Layer.
#### flyte-binary
1. Go to the values YAML file you used to install Flyte
2. Find the `auth` section and follow the inline comments to insert your configuration:
```yaml
auth:
  enabled: true
  oidc:
    # baseUrl: https://<keycloak-url>/auth/realms/<keycloak-realm> # Uncomment for Keycloak and update with your installation host and realm name
    # baseUrl: https://login.microsoftonline.com/<tenant-id>/v2.0 # Uncomment for Azure AD
    # For Okta, use the Issuer URI of the custom auth server:
    baseUrl: https://dev-<org-id>.okta.com/oauth2/<auth-server-id>
    # Use the client ID and secret generated by your IdP for the first OIDC registration in the "Identity Management layer : OIDC" section of this guide
    clientId: <oidc-client-id>
    clientSecret: <oidc-client-secret>
  internal:
    # Use the client ID generated by your IdP for the flytepropeller app registration
    clientId: <flytepropeller-client-id>
    # Use the secret generated by your IdP for flytepropeller
    clientSecret: '<flytepropeller-client-secret>'
    # Use the bcrypt hash for the clientSecret
    clientSecretHash: <your-flytepropeller-secret-bcrypt-hash>
  authorizedUris:
    # Use here the exact same value used for 'audience' when the Authorization Server was configured
    - https://<your-flyte-deployment-URL>
```
3. Find the `inline` section of the values file and add the following content, replacing where needed:
```yaml
inline:
  auth:
    appAuth:
      authServerType: External
      externalAuthServer:
        # baseUrl: https://<keycloak-url>/auth/realms/<keycloak-realm> # Uncomment for Keycloak and update with your installation host and realm name
        # baseUrl: https://login.microsoftonline.com/<tenant-id>/v2.0 # Uncomment for Azure AD
        # For Okta, use the Issuer URI of the custom auth server:
        baseUrl: https://dev-<org-id>.okta.com/oauth2/<auth-server-id>
        metadataUrl: .well-known/oauth-authorization-server
      thirdPartyConfig:
        flyteClient:
          # Use the client ID generated by your IdP for the `flytectl` app registration
          clientId: <flytectl-client-id>
          redirectUri: http://localhost:53593/callback
          scopes:
            - offline
            - all
    userAuth:
      openId:
        # baseUrl: https://<keycloak-url>/auth/realms/<keycloak-realm> # Uncomment for Keycloak and update with your installation host and realm name
        # baseUrl: https://login.microsoftonline.com/<tenant-id>/v2.0 # Uncomment for Azure AD
        # For Okta, use the Issuer URI of the custom auth server:
        baseUrl: https://dev-<org-id>.okta.com/oauth2/<auth-server-id>
        scopes:
          - profile
          - openid
          # - offline_access # Uncomment if your IdP supports issuing refresh tokens (optional)
        # Use the client ID and secret generated by your IdP for the first OIDC registration in the "Identity Management layer : OIDC" section of this guide
        clientId: <oidc-client-id>
```
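As with the OIDC setup, you can optionally verify that the auth server metadata is reachable before upgrading; it is served at the `metadataUrl` path relative to `baseUrl` (the Okta URL below is a placeholder; substitute your own values):
```bash
curl -s https://dev-<org-id>.okta.com/oauth2/<auth-server-id>/.well-known/oauth-authorization-server
```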
4. Save your changes
5. Upgrade your Helm release with the new configuration:
```bash
helm upgrade <RELEASE_NAME> flyteorg/flyte-binary -n <YOUR_NAMESPACE> --values <YOUR_VALUES_FILE>.yaml
```
#### flyte-core
1. Find the `auth` section in your Helm values file and replace the necessary data:
> If you were previously using the internal auth server, make sure to delete the entire `selfAuthServer` section from your values file.
```yaml
configmap:
  adminServer:
    auth:
      appAuth:
        authServerType: External
        # Optional: set the external auth server baseUrl if it differs from the OpenID baseUrl.
        externalAuthServer:
          # Replace this with your deployment URL. It will be used by flyteadmin to validate the token audience.
          allowedAudience: https://<your-flyte-deployment-URL>
          # baseUrl: https://<keycloak-url>/auth/realms/<keycloak-realm> # Uncomment for Keycloak and update with your installation host and realm name
          # baseUrl: https://login.microsoftonline.com/<tenant-id>/v2.0 # Uncomment for Azure AD
          # For Okta, use the Issuer URI of the custom auth server:
          baseUrl: https://dev-<org-id>.okta.com/oauth2/<auth-server-id>
          metadataUrl: .well-known/openid-configuration
      userAuth:
        openId:
          # baseUrl: https://<keycloak-url>/auth/realms/<keycloak-realm> # Uncomment for Keycloak and update with your installation host and realm name
          # baseUrl: https://login.microsoftonline.com/<tenant-id>/v2.0 # Uncomment for Azure AD
          # For Okta, use the Issuer URI of the custom auth server:
          baseUrl: https://dev-<org-id>.okta.com/oauth2/<auth-server-id>
          scopes:
            - profile
            - openid
            # - offline_access # Uncomment if your IdP supports issuing refresh tokens (optional)
          clientId: <oidc-client-id>
secrets:
  adminOauthClientCredentials:
    enabled: true # See the "Disable Helm secret management" section if you need to manage this Secret yourself
    # Replace with the client_secret provided by your IdP for flytepropeller
    clientSecret: <flytepropeller-client-secret>
    # Replace with the client_id provided by your IdP for flytepropeller
    clientId: <flytepropeller-client-id>
```
2. Save your changes
3. Upgrade your Helm release with the new configuration:
```bash
helm upgrade <RELEASE_NAME> flyteorg/flyte-core -n <YOUR_NAMESPACE> --values <YOUR_VALUES_FILE>.yaml
```
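Once the upgrade completes, a quick end-to-end check is to run any `flytectl` command against your deployment; it should now prompt you to authenticate against the external auth server:
```bash
flytectl get projects
```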