# Data input/output
> This bundle contains all pages in the Data input/output section.
> Source: https://www.union.ai/docs/v1/union/user-guide/data-input-output/

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output ===

# Data input/output

> **📝 Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.


This section covers how to manage data input and output in Union.ai.
Union.ai also supports all the [Data input/output features of Flyte](https://docs-builder.pages.dev/docs/flyte/user-guide/data-input-output/).

| Section | Description |
|---------------------------------------------------|----------------------------------------------------|
| **Data input/output > FlyteFile and FlyteDirectory** | Use `FlyteFile` and `FlyteDirectory` to easily pass files and directories across tasks. |
| **Data input/output > Downloading with FlyteFile and FlyteDirectory** | Details on how files and directories are downloaded with `FlyteFile` and `FlyteDirectory`. |
| **Data input/output > StructuredDataset** | Details on how `StructuredDataset` is used as a general dataframe type. |
| **Data input/output > Dataclass** | Details on how to use dataclasses across tasks. |
| **Data input/output > Pydantic BaseModel** | Details on how to use Pydantic models across tasks. |
| **Data input/output > Accessing attributes** | Details on how to directly access attributes on output promises for lists, dictionaries, dataclasses, and combinations of these types. |
| **Data input/output > Enum type** | Details on how to use Enums across tasks. |
| **Data input/output > Pickle type** | Details on how to use pickled objects across tasks for generalized typing. |
| **Data input/output > PyTorch type** | Details on how to use torch tensors and models across tasks. |
| **Data input/output > TensorFlow types** | Details on how to use TensorFlow tensors and models across tasks. |
| **Data input/output > Accelerated datasets** | Upload your data once and access it from any task. |

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/flyte-file-and-flyte-directory ===


# FlyteFile and FlyteDirectory

In Union.ai, each task runs in its own container. This means that a file or directory created locally in one task will not automatically be available in other tasks.

The natural way to solve this problem is for the source task to upload the file or directory to a common location (like the Union.ai object store) and then pass a reference to that location to the destination task, which then downloads or streams the data.

Since this is such a common use case, the Union SDK provides the [`FlyteFile`](https://www.union.ai/docs/v1/union/api-reference/flytekit-sdk/packages/flytekit.types.file.file) and [`FlyteDirectory`](https://www.union.ai/docs/v1/union/api-reference/flytekit-sdk/packages/flytekit.types.directory.types) classes, which automate this process.

## How the classes work

The classes work by wrapping a file or directory location path and, if necessary, maintaining the persistence of the referenced file or directory across task containers.

When you return a `FlyteFile` (or `FlyteDirectory`) object from a task, Union.ai checks to see if the underlying file or directory is local to the task container or if it already exists in a remote location.

If it is local to the source container, then Union.ai automatically uploads it to an object store so that it is not lost when the task container is discarded on task completion.
If the file or directory is already remote, then no upload is performed.

When the `FlyteFile` (or `FlyteDirectory`) is passed into the next task, the location of the source file (or directory) is available within the object and it can be downloaded or streamed.
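
The local-versus-remote decision described above can be pictured with a small toy model. This is not flytekit's actual implementation; it simply assumes that a path with no URI scheme, or with a `file://` scheme, is local to the container and would be uploaded, while any other scheme (such as `s3://`) is already remote and is passed through. The helper name `needs_upload` is hypothetical.

```python
from urllib.parse import urlparse


def needs_upload(path: str) -> bool:
    """Toy model: return True if `path` looks local to the task container.

    Hypothetical helper for illustration only; flytekit's real check may differ.
    """
    scheme = urlparse(path).scheme
    # No scheme (e.g. "/tmp/data.txt") or an explicit file:// URI means local.
    return scheme in ("", "file")


if __name__ == "__main__":
    for p in ["/tmp/data.txt", "file:///tmp/data.txt", "s3://my-bucket/data.txt"]:
        print(p, "->", "upload" if needs_upload(p) else "pass through")
```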

## Local examples

> [!NOTE] Local means local to the container
> The terms _local file_ and _local directory_ in this section refer to a file or directory local to the container running a task in Union.ai.
> They do not refer to a file or directory on your local machine.

### Local file example

Let's say you have a local file in the container running `task_1` that you want to make accessible in the next task, `task_2`.
To do this, you create a `FlyteFile` object using the local path of the file, and then pass the `FlyteFile` object as part of your workflow, like this:

```python
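import os

import union
from flytekit import current_context

@union.task
def task_1() -> union.FlyteFile:
    # Write a file into the task's working directory (local to this container)
    local_path = os.path.join(current_context().working_directory, "data.txt")
    with open(local_path, mode="w") as f:
        f.write("Here is some sample data.")
    # Returning a FlyteFile causes the local file to be uploaded
    return union.FlyteFile(path=local_path)

@union.task
def task_2(ff: union.FlyteFile):
    # Stream the file contents from the object store
    with ff.open(mode="r") as f:
        file_contents = f.read()
    print(file_contents)

@union.workflow
def wf():
    ff = task_1()
    task_2(ff=ff)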
```

Union.ai handles the passing of the `FlyteFile` `ff` in the workflow `wf` from `task_1` to `task_2`:

* The `FlyteFile` object is initialized with the path (local to the `task_1` container) of the file you wish to share.
* When the `FlyteFile` is passed out of `task_1`, Union.ai uploads the local file to a unique location in the Union.ai object store. A randomly generated, universally unique location is used to ensure that subsequent uploads of other files never overwrite each other.
* The object store location is used to initialize the URI attribute of a Flyte `Blob` object. Note that Flyte objects are not Python objects. They exist at the workflow level and are used to pass data between task containers. For more details, see [Flyte Core Language Specification > Literals](https://www.union.ai/docs/v1/union/api-reference/flyteidl).
* The `Blob` object is passed to `task_2`.
* Because the type of the input parameter of `task_2` is `FlyteFile`, Union.ai converts the `Blob` back into a `FlyteFile` and sets the `remote_source` attribute of that `FlyteFile` to the URI of the `Blob` object.
* Inside `task_2` you can now perform a [`FlyteFile.open()`](https://www.union.ai/docs/v1/union/api-reference/flytekit-sdk/packages/flytekit.types.file.file) and read the file contents.

### Local directory example

Below is an equivalent local example for `FlyteDirectory`. The process of passing the `FlyteDirectory` between tasks is essentially identical to the `FlyteFile` example above.

```python
import os

import union
from flytekit import current_context

@union.task
def task1() -> union.FlyteDirectory:
    # Create a new local directory
    p = os.path.join(current_context().working_directory, "my_new_directory")
    os.makedirs(p)

    # Create and write to two files
    with open(os.path.join(p, "file_1.txt"), "w") as file1:
        file1.write("This is file 1.")
    with open(os.path.join(p, "file_2.txt"), "w") as file2:
        file2.write("This is file 2.")

    return union.FlyteDirectory(p)

@union.task
def task2(fd: union.FlyteDirectory):
    # Get a list of the directory contents using os (returns strings);
    # passing fd itself downloads the directory via __fspath__
    items = os.listdir(fd)
    print(type(items[0]))

    # Get a list of the directory contents using FlyteDirectory (returns FlyteFiles)
    files = union.FlyteDirectory.listdir(fd)
    print(type(files[0]))
    with open(files[0], mode="r") as f:
        d = f.read()
    print(f"The contents of the first file are: {d}")

@union.workflow
def workflow():
    fd = task1()
    task2(fd=fd)
```

## Changing the data upload location

> With Union.ai BYOC, the upload location is configurable.

By default, Union.ai uploads local files or directories to the default **raw data store** (Union.ai's dedicated internal object store).
However, you can change the upload location by setting the raw data prefix to your own bucket or specifying the `remote_path` for a `FlyteFile` or `FlyteDirectory`.

> [!NOTE] Setting up your own object store bucket
> For details on how to set up your own object store bucket, consult the directions for your cloud provider:
>
> * [Enabling AWS S3](https://www.union.ai/docs/v1/union/user-guide/deployment/enabling-aws-resources/enabling-aws-s3)
> * [Enabling Google Cloud Storage](https://www.union.ai/docs/v1/union/user-guide/deployment/enabling-gcp-resources/enabling-google-cloud-storage)
> * [Enabling Azure Blob Storage](https://www.union.ai/docs/v1/union/user-guide/deployment/enabling-azure-resources/enabling-azure-blob-storage)

### Changing the raw data prefix

If you would like files or directories to be uploaded to your own bucket, you can specify the AWS, GCS, or Azure bucket in the **raw data prefix** parameter.
This can be set at the workflow level on registration, or per execution on the command line or in the UI.


Union.ai will create a directory with a unique, random name in your bucket for each `FlyteFile` or `FlyteDirectory` data write to guarantee that you never overwrite your data.
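
As a rough illustration of why a fresh random directory per write avoids collisions, the sketch below builds a destination URI under a raw data prefix. The prefix `s3://my-bucket/raw-data` and the helper `unique_upload_uri` are hypothetical, and the exact path layout Union.ai uses may differ; only the idea of a new random path component per write is the point.

```python
import posixpath
import uuid


def unique_upload_uri(raw_data_prefix: str, local_path: str) -> str:
    """Toy model: place each upload in its own randomly named directory."""
    filename = posixpath.basename(local_path)
    # uuid4().hex is 32 random hex characters, so two writes of files with
    # the same name still land in different directories.
    return f"{raw_data_prefix.rstrip('/')}/{uuid.uuid4().hex}/{filename}"


if __name__ == "__main__":
    prefix = "s3://my-bucket/raw-data"
    print(unique_upload_uri(prefix, "/tmp/data.txt"))
    print(unique_upload_uri(prefix, "/tmp/data.txt"))  # a different directory
```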

### Specifying `remote_path` for a `FlyteFile` or `FlyteDirectory`

If you specify `remote_path` when initializing your `FlyteFile` (for `FlyteDirectory`, the corresponding constructor parameter is `remote_directory`), the underlying data is written to that precise location with no randomization.

> [!NOTE] Using remote_path will overwrite data
> If you set `remote_path` to a static string, subsequent runs of the same task will overwrite the file.
> If you want to use a dynamically generated path, you will have to generate it yourself.
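
The toy sketch below contrasts the two cases using an in-memory dict as a stand-in for an object store: writing twice to the same static path replaces the first result, while a path that includes a per-run identifier keeps both. The bucket, the run IDs, and the helpers `write` and `per_run_path` are hypothetical; in a real task you would pass a path built this way as the `remote_path` argument when constructing the `FlyteFile`.

```python
import posixpath
from typing import Dict

# Toy object store: maps URI -> contents. Stands in for a real bucket.
store: Dict[str, str] = {}


def write(uri: str, contents: str) -> None:
    store[uri] = contents  # writing to an existing key replaces it


def per_run_path(base: str, run_id: str, filename: str) -> str:
    """Hypothetical helper: build a distinct remote path for each run."""
    return posixpath.join(base, run_id, filename)


if __name__ == "__main__":
    static = "s3://my-bucket/outputs/data.txt"
    write(static, "run 1")
    write(static, "run 2")  # same static path: the data from run 1 is gone
    print(store[static])  # run 2

    for run_id in ("run-1", "run-2"):
        write(per_run_path("s3://my-bucket/outputs", run_id, "data.txt"), run_id)
    print(sorted(k for k in store if "/run-" in k))
```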

## Remote examples

### Remote file example

In the example above, we started with a local file.
To preserve that file across the task boundary, Union.ai uploaded it to the Union.ai object store before passing it to the next task.

You can also _start with a remote file_, simply by initializing the `FlyteFile` object with a URI pointing to a remote source. For example:

```python
import union

@union.task
def task_1() -> union.FlyteFile:
    remote_path = "https://people.sc.fsu.edu/~jburkardt/data/csv/biostats.csv"
    return union.FlyteFile(path=remote_path)
```

In this case, no uploading is needed because the source file is already in a remote location.
When the object is passed out of the task, it is converted into a `Blob` with the remote path as the URI.
After the `FlyteFile` is passed to the next task, you can call `FlyteFile.open()` on it, just as before.

If you don't intend to pass the `FlyteFile` to the next task, but instead want to read the contents of the remote file within the task, you can use `FlyteFile.from_source()`:

```python
import json

import union

@union.task
def load_json():
    uri = "gs://my-bucket/my-directory/example.json"
    my_json = union.FlyteFile.from_source(uri)

    # Load the JSON file into a dictionary and print it
    # (the built-in open() downloads the file via __fspath__)
    with open(my_json, "r") as json_file:
        data = json.load(json_file)
    print(data)
```

When initializing a `FlyteFile` with a remote file location, all URI schemes supported by `fsspec` are supported, including `http` and `https` (Web), `gs` (Google Cloud Storage), `s3` (AWS S3), and `abfs` and `abfss` (Azure Blob Filesystem).
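
If your task accepts URIs from users, a cheap pre-check before building a `FlyteFile` can catch a missing or misspelled scheme early. The sketch below only checks against the schemes named above; `fsspec` supports others, and in practice each scheme also needs its backend package installed (for example, `s3fs` for `s3://` or `gcsfs` for `gs://`). The names `LISTED_SCHEMES` and `remote_scheme` are hypothetical.

```python
from typing import Optional
from urllib.parse import urlparse

# Schemes named in the text above; fsspec supports more than these.
LISTED_SCHEMES = {"http", "https", "gs", "s3", "abfs", "abfss"}


def remote_scheme(uri: str) -> Optional[str]:
    """Return the URI's scheme if it is one of LISTED_SCHEMES, else None."""
    scheme = urlparse(uri).scheme  # urlparse lowercases the scheme
    return scheme if scheme in LISTED_SCHEMES else None


if __name__ == "__main__":
    for uri in ["gs://my-bucket/my-directory/example.json", "my-bucket/example.json"]:
        print(uri, "->", remote_scheme(uri))
```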

### Remote directory example

Below is an equivalent remote example for `FlyteDirectory`. The process of passing the `FlyteDirectory` between tasks is essentially identical to the `FlyteFile` example above.

```python
import union

@union.task
def task1() -> union.FlyteDirectory:
    p = "https://people.sc.fsu.edu/~jburkardt/data/csv/"
    return union.FlyteDirectory(p)

@union.task
def task2(fd: union.FlyteDirectory):
    # Get a list of the directory contents and display the first CSV file
    files = union.FlyteDirectory.listdir(fd)
    with open(files[0], mode="r") as f:
        d = f.read()
    print(f"The first csv is: \n{d}")

@union.workflow
def workflow():
    fd = task1()
    task2(fd=fd)
```

## Streaming

In the above examples, we showed how to access the contents of `FlyteFile` by calling `FlyteFile.open()`.
The object returned by `FlyteFile.open()` is a stream. In the above examples, the files were small, so a simple `read()` was used.
But for large files, you can iterate through the contents of the stream:

```python
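import union

@union.task
def task_1() -> union.FlyteFile:
    remote_path = "https://sample-videos.com/csv/Sample-Spreadsheet-100000-rows.csv"
    return union.FlyteFile(path=remote_path)

@union.task
def task_2(ff: union.FlyteFile):
    row_count = 0
    # Iterate over the stream line by line instead of reading it all at once
    with ff.open(mode="r") as f:
        for row in f:
            row_count += 1  # Replace with your own per-row processing
    print(f"Processed {row_count} rows")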
```
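
Line iteration suits text formats such as CSV. For binary data without line breaks, a common alternative is to read fixed-size chunks. The helper below is a generic sketch that works with any binary file-like object; we assume, without guaranteeing it here, that the stream returned by `ff.open(mode="rb")` supports `read(n)` like a regular file. `iter_chunks` is a hypothetical name.

```python
import io
from typing import BinaryIO, Iterator


def iter_chunks(stream: BinaryIO, chunk_size: int = 1024 * 1024) -> Iterator[bytes]:
    """Yield successive fixed-size chunks from a binary file-like object.

    Only `chunk_size` bytes are held in memory at a time, so this also works
    for files larger than the available RAM.
    """
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:  # b"" signals end of stream
            break
        yield chunk


if __name__ == "__main__":
    # io.BytesIO stands in for a binary stream such as one from ff.open(mode="rb")
    fake_stream = io.BytesIO(b"x" * 10)
    print([len(c) for c in iter_chunks(fake_stream, chunk_size=4)])  # [4, 4, 2]
```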

## Downloading

Alternatively, you can download the contents of a `FlyteFile` object to a local file in the task container.
There are two ways to do this: **implicitly** and **explicitly**.

### Implicit downloading

The source file of a `FlyteFile` object is downloaded to the local container file system automatically whenever you pass the `FlyteFile` object to a function that calls its `__fspath__()` method.

`FlyteFile` implements the `os.PathLike` interface and therefore the `__fspath__()` method.
`FlyteFile`'s implementation of `__fspath__()` performs a download of the source file to the local container storage and returns the path to that local file.
This enables many common file-related operations in Python to be performed on the `FlyteFile` object.

The most prominent example of such an operation is calling Python's built-in `open()` function with a `FlyteFile`:

```python
import union

@union.task
def task_2(ff: union.FlyteFile):
    # The built-in open() calls ff.__fspath__(), which downloads the file first
    with open(ff, mode="r") as f:
        file_contents = f.read()
```

> [!NOTE] open() vs ff.open()
> Note the difference between
>
> `ff.open(mode="r")`
>
> and
>
> `open(ff, mode="r")`
>
> The former calls the `FlyteFile.open()` method and returns a file-like stream that reads from the remote source without downloading the file first.
> The latter calls the built-in Python function `open()`, downloads the specified `FlyteFile` to the local container file system,
> and returns a handle to that file.
>
> Many other Python file operations (essentially, any that accept an `os.PathLike` object) can also be performed on a `FlyteFile`
> object and result in an automatic download.
>
> See [Downloading with FlyteFile and FlyteDirectory](./downloading-with-ff-and-fd) for more information.
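
The difference between the two calls can be reproduced with a toy class in plain Python. `LazyRemoteFile` is hypothetical and is not how `FlyteFile` is implemented: the "remote" file is just a path in another directory and "downloading" is a local copy. Its `open()` method reads straight from the source, while passing the object itself to the built-in `open()` goes through `__fspath__()` and creates the local copy.

```python
import builtins
import os
import shutil
import tempfile


class LazyRemoteFile(os.PathLike):
    """Toy stand-in for FlyteFile; not flytekit's implementation."""

    def __init__(self, remote_source: str, local_path: str):
        self.remote_source = remote_source
        self.path = local_path  # where a download will land
        self.downloaded = False

    def open(self, mode: str = "r"):
        # Like ff.open(): read from the source, no local copy is made
        return builtins.open(self.remote_source, mode)

    def download(self) -> str:
        # Like ff.download(): copy the source to self.path once
        if not self.downloaded:
            shutil.copyfile(self.remote_source, self.path)
            self.downloaded = True
        return self.path

    def __fspath__(self) -> str:
        # Called by the built-in open(), os.stat(), pathlib.Path(), and similar
        return self.download()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as remote, tempfile.TemporaryDirectory() as local:
        src = os.path.join(remote, "data.txt")
        with open(src, "w") as f:
            f.write("Here is some sample data.")
        lf = LazyRemoteFile(src, os.path.join(local, "data.txt"))

        with lf.open() as f:  # streams from the source
            print(f.read(), os.path.exists(lf.path))  # Here is some sample data. False
        with open(lf) as f:  # built-in open() -> __fspath__() -> download
            print(f.read(), os.path.exists(lf.path))  # Here is some sample data. True
```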

### Explicit downloading

You can also explicitly download a `FlyteFile` to the local container file system by calling `FlyteFile.download()`:

```python
import union

@union.task
def task_2(ff: union.FlyteFile):
    # Download the file to local container storage and get its local path
    local_path = ff.download()
```

This method is typically used when you want to download the file without immediately reading it.


## Typed aliases

The [Union SDK](https://www.union.ai/docs/v1/union/api-reference/union-sdk) defines some aliases of `FlyteFile` with specific type annotations.
Specifically, `FlyteFile` has the following [aliases for specific file types](https://www.union.ai/docs/v1/union/api-reference/flytekit-sdk/packages/flytekit.types.file.file):

* `HDF5EncodedFile`
* `HTMLPage`
* `JoblibSerializedFile`
* `JPEGImageFile`
* `PDFFile`
* `PNGImageFile`
* `PythonPickledFile`
* `PythonNotebook`
* `SVGImageFile`

Similarly, `FlyteDirectory` has the following [aliases](https://www.union.ai/docs/v1/union/api-reference/flytekit-sdk/packages/flytekit.types.directory.types):

* `TensorboardLogs`
* `TFRecordsDirectory`

These aliases can optionally be used when handling a file or directory of the specified type, although the object itself will still be a `FlyteFile` or `FlyteDirectory`.
The aliased versions of the classes are syntactic markers that enforce agreement between type annotations in the signatures of task functions, but they do not perform any checks on the actual contents of the file.
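
A toy analogue using only the standard library shows why such aliases are markers rather than validators: the type parameter labels the annotation, but constructing an instance never inspects the file. `File`, `pdf`, and `PDFFileAlias` are hypothetical names, and flytekit's own alias definitions may differ in detail.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class File(Generic[T]):
    """Toy stand-in for FlyteFile: wraps a path, nothing more."""

    def __init__(self, path: str):
        self.path = path


# An alias in the spirit of PDFFile: the type parameter is only a label
pdf = TypeVar("pdf")
PDFFileAlias = File[pdf]

doc = PDFFileAlias("notes.txt")  # accepted: nothing checks that this is a PDF
print(type(doc).__name__, doc.path)  # File notes.txt
```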

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/downloading-with-ff-and-fd ===

# Downloading with FlyteFile and FlyteDirectory

The basic idea behind `FlyteFile` and `FlyteDirectory` is that they represent files and directories in remote storage.
When you work with these objects in your tasks, you are working with references to the remote files and directories.

Of course, at some point you will need to access the actual contents of these files and directories,
which means that they have to be downloaded to the local file system of the task container.

The actual files and directories of a `FlyteFile` or `FlyteDirectory` are downloaded to the local file system of the task container in two ways:
* Explicitly, through a call to the `download` method.
* Implicitly, through automatic downloading.
  This occurs when an external function is called on the `FlyteFile` or `FlyteDirectory` that itself calls the `__fspath__` method.

To write efficient and performant task and workflow code, it is particularly important to have a solid understanding of when exactly downloading occurs.
Let's look at some examples showing when the contents of `FlyteFile` and `FlyteDirectory` objects are downloaded to the local task container file system.

## FlyteFile

**Calling `download` on a FlyteFile**

```python
import os

import union
from union import FlyteFile

@union.task
def my_task(ff: FlyteFile):
    print(os.path.isfile(ff.path))  # This will print False as nothing has been downloaded
    ff.download()
    print(os.path.isfile(ff.path))  # This will print True as the FlyteFile was downloaded
```

Note that we pass `ff.path`, which is of type `typing.Union[str, os.PathLike]`, to `os.path.isfile` rather than passing `ff` directly.
In the next example, we will see that using `os.path.isfile(ff)` invokes `__fspath__` which downloads the file.

**Implicit downloading by `__fspath__`**

In order to make use of some functions like `os.path.isfile` that you may be used to using with regular file paths, `FlyteFile`
implements a `__fspath__` method that downloads the remote contents to the `path` of `FlyteFile` local to the container.

```python
import os

import union
from union import FlyteFile

@union.task
def my_task(ff: FlyteFile):
    print(os.path.isfile(ff.path))  # This will print False as nothing has been downloaded
    print(os.path.isfile(ff))  # This will print True as os.path.isfile(ff) downloads via __fspath__
    print(os.path.isfile(ff.path))  # This will again print True as the file was downloaded
```

It is important to be aware of any operations on your `FlyteFile` that might call `__fspath__` and result in downloading.
Examples include calling `open(ff, mode="r")` directly on a `FlyteFile` (rather than on its `path` attribute) to read its contents,
or similarly calling `shutil.copy` or `pathlib.Path` directly on a `FlyteFile`.
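
To check whether a particular operation goes through `__fspath__` before you use it on a real `FlyteFile`, you can try it on a small probe object. The sketch below is plain Python and not part of the Union SDK: `FspathProbe` and `calls_made_by` are hypothetical helpers that count `__fspath__` calls instead of downloading anything.

```python
import os
import pathlib
import shutil
import tempfile


class FspathProbe(os.PathLike):
    """Toy PathLike that counts how often __fspath__ is called."""

    def __init__(self, real_path: str):
        self.real_path = real_path
        self.calls = 0

    def __fspath__(self) -> str:
        self.calls += 1  # for a FlyteFile, this is where a download would happen
        return self.real_path


def calls_made_by(operation, probe: FspathProbe) -> int:
    """Run `operation(probe)` and report how many __fspath__ calls it made."""
    before = probe.calls
    operation(probe)
    return probe.calls - before


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "data.txt")
        with open(src, "w") as f:
            f.write("sample")
        probe = FspathProbe(src)
        ops = {
            "os.path.isfile(p)": os.path.isfile,
            "open(p)": lambda p: open(p).close(),
            "pathlib.Path(p)": pathlib.Path,
            "shutil.copy(p, dst)": lambda p: shutil.copy(p, os.path.join(tmp, "copy.txt")),
            "os.path.isfile(p.real_path)": lambda p: os.path.isfile(p.real_path),
        }
        for name, op in ops.items():
            print(f"{name}: {calls_made_by(op, probe)} __fspath__ call(s)")
```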

## FlyteDirectory

**Calling `download` on a FlyteDirectory**

```python
import os

import union
from union import FlyteDirectory

@union.task
def my_task(fd: FlyteDirectory):
    print(os.listdir(fd.path))  # This will print nothing as the directory has not been downloaded
    fd.download()
    print(os.listdir(fd.path))  # This will print the files present in the directory as it has been downloaded
```

As with `FlyteFile` above, note that we pass `fd.path`, which is of type `typing.Union[str, os.PathLike]`, to `os.listdir` rather than passing `fd` directly.
Again, we will see that this is because of the invocation of `__fspath__` when `os.listdir(fd)` is called.

**Implicit downloading by `__fspath__`**

In order to make use of some functions like `os.listdir` that you may be used to using with directories, `FlyteDirectory`
implements a `__fspath__` method that downloads the remote contents to the `path` of `FlyteDirectory` local to the container.

```python
import os

import union
from union import FlyteDirectory

@union.task
def my_task(fd: FlyteDirectory):
    print(os.listdir(fd.path))  # This will print nothing as the directory has not been downloaded
    print(os.listdir(fd))  # This will print the files present in the directory as os.listdir(fd) downloads via __fspath__
    print(os.listdir(fd.path))  # This will again print the files present in the directory as it has been downloaded
```

It is important to be aware of any operations on your `FlyteDirectory` that might call `__fspath__` and result in downloading.
Other examples include calling `os.stat` directly on a `FlyteDirectory` (rather than on its `path` attribute) to get the status of the path,
or similarly calling `os.path.isdir` on a `FlyteDirectory` to check if a directory exists.
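
The same pattern applies to directories. In the toy sketch below, `LazyRemoteDir` is a hypothetical stand-in for `FlyteDirectory` whose "remote" contents are a dict: `os.listdir(fd.path)` sees an empty directory, while `os.listdir(fd)` triggers `__fspath__` and materializes the files first.

```python
import os
import tempfile
from typing import Dict


class LazyRemoteDir(os.PathLike):
    """Toy stand-in for FlyteDirectory; not flytekit's implementation.

    The "remote" directory is simulated by a dict of filename -> contents.
    """

    def __init__(self, remote_files: Dict[str, str], local_path: str):
        self.remote_files = remote_files
        self.path = local_path
        os.makedirs(self.path, exist_ok=True)  # exists, but empty until downloaded
        self.downloaded = False

    def __fspath__(self) -> str:
        # "Download" every remote file into the local directory on first use
        if not self.downloaded:
            for name, contents in self.remote_files.items():
                with open(os.path.join(self.path, name), "w") as f:
                    f.write(contents)
            self.downloaded = True
        return self.path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        remote = {"file_1.txt": "This is file 1.", "file_2.txt": "This is file 2."}
        fd = LazyRemoteDir(remote, os.path.join(tmp, "my_dir"))
        print(os.listdir(fd.path))  # [] (nothing downloaded yet)
        print(sorted(os.listdir(fd)))  # ['file_1.txt', 'file_2.txt']
        print(os.path.isdir(fd), fd.downloaded)  # True True
```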

**Inspecting the contents of a directory without downloading using `crawl`**

As we saw above, using `os.listdir` on a `FlyteDirectory` to view the contents in remote blob storage
results in the contents being downloaded to the task container. To avoid this, use the `crawl` method, which lets you inspect
the contents of the directory without calling `__fspath__` and therefore without downloading them.

```python
import os

import union
from flytekit import current_context
from union import FlyteDirectory

@union.task
def task1() -> FlyteDirectory:
    p = os.path.join(current_context().working_directory, "my_new_directory")
    os.makedirs(p)

    # Create and write to two files
    with open(os.path.join(p, "file_1.txt"), 'w') as file1:
        file1.write("This is file 1.")
    with open(os.path.join(p, "file_2.txt"), 'w') as file2:
        file2.write("This is file 2.")

    return FlyteDirectory(p)

@union.task
def task2(fd: FlyteDirectory):
    print(os.listdir(fd.path))  # This will print nothing as the directory has not been downloaded
    print(list(fd.crawl()))  # This will print the files present in the remote blob storage
    # e.g. [('s3://union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80', 'file_1.txt'), ('s3://union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80', 'file_2.txt')]
    print(list(fd.crawl(detail=True)))  # This will print the files present in the remote blob storage with details including type, the time it was created, and more
    # e.g. [('s3://union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80', {'file_1.txt': {'Key': 'union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80/file_1.txt', 'LastModified': datetime.datetime(2024, 7, 9, 16, 16, 21, tzinfo=tzlocal()), 'ETag': '"cfb2a3740155c041d2c3e13ad1d66644"', 'Size': 15, 'StorageClass': 'STANDARD', 'type': 'file', 'size': 15, 'name': 'union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80/file_1.txt'}}), ('s3://union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80', {'file_2.txt': {'Key': 'union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80/file_2.txt', 'LastModified': datetime.datetime(2024, 7, 9, 16, 16, 21, tzinfo=tzlocal()), 'ETag': '"500d703f270d4bc034e159480c83d329"', 'Size': 15, 'StorageClass': 'STANDARD', 'type': 'file', 'size': 15, 'name': 'union-contoso/ke/fe503def6ebe04fa7bba-n0-0/160e7266dcaffe79df85489771458d80/file_2.txt'}})]
    print(os.listdir(fd.path))  # This will again print nothing as the directory has not been downloaded
```
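
To picture the shape of what `crawl()` yields without involving remote storage, here is a standard-library analogue that walks a local tree with `os.walk` and yields `(base, relative_path)` pairs without opening or copying any file. `crawl_local` is a hypothetical name; it has no `detail` option, and the real method lists object-store keys rather than walking a file system.

```python
import os
import tempfile
from typing import Iterator, Tuple


def crawl_local(base: str) -> Iterator[Tuple[str, str]]:
    """Toy analogue of FlyteDirectory.crawl() for a local directory tree.

    Yields (base, relative_path) pairs; no file is opened or copied.
    """
    for root, _dirs, files in os.walk(base):
        for name in files:
            yield base, os.path.relpath(os.path.join(root, name), base)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("file_1.txt", "file_2.txt"):
            with open(os.path.join(tmp, name), "w") as f:
                f.write("sample")
        print(sorted(crawl_local(tmp)))
```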

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/task-input-and-output ===

# Task input and output

The Union.ai workflow engine automatically manages the passing of data from task to task, and to the workflow output.

This mechanism relies on enforcing strong typing of task function parameters and return values.
This enables the workflow engine to efficiently marshal and unmarshal values from one task container to the next.

The actual data is temporarily stored in Union.ai's internal object store within your data plane (AWS S3, Google Cloud Storage, or Azure Blob Storage, depending on your cloud provider).

## Metadata and raw data

Union.ai distinguishes between metadata and raw data.

Primitive values (`int`, `str`, etc.) are stored directly in the metadata store, while complex data objects (`pandas.DataFrame`, `FlyteFile`, etc.) are stored by reference, with the reference pointer in the metadata store and the actual data in the raw data store.
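
The split can be sketched as a toy model in which two dicts stand in for the metadata store and the raw data store. This is illustrative only: the real stores are object-store buckets, and values are serialized as Flyte literals rather than kept as Python objects. The names `record_output`, `metadata_store`, `raw_data_store`, and the `s3://raw-bucket` prefix are hypothetical, and a list of lists stands in for a large object such as a DataFrame.

```python
import uuid
from typing import Any, Dict

# Toy stand-ins for the two stores (the real ones are object-store buckets)
metadata_store: Dict[str, Any] = {}
raw_data_store: Dict[str, Any] = {}

PRIMITIVES = (bool, int, float, str)


def record_output(name: str, value: Any, raw_prefix: str = "s3://raw-bucket") -> None:
    """Toy model: keep primitives inline, store everything else by reference."""
    if isinstance(value, PRIMITIVES):
        metadata_store[name] = {"value": value}
    else:
        uri = f"{raw_prefix}/{uuid.uuid4().hex}"
        raw_data_store[uri] = value  # the actual data goes to the raw data store
        metadata_store[name] = {"ref": uri}  # metadata only holds the pointer


if __name__ == "__main__":
    record_output("row_count", 2)
    record_output("table", [["apple", 1], ["banana", 2]])
    print(metadata_store)
```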

## Metadata store

The metadata store is located in the dedicated Union.ai object store in your data plane.
Depending on your cloud provider, this may be an AWS S3, Google Cloud Storage, or Azure Blob Storage bucket.

This data is accessible to the control plane. It is used to run and manage workflows and is surfaced in the UI.

## Raw data store

The raw data store is, by default, also located in the dedicated Union.ai object store in your data plane.

However, this location can be overridden per workflow or per execution using the **raw data prefix** parameter.

The data in the raw data store is not accessible to the control plane and will only be surfaced in the UI if your code explicitly does so (for example, in a Deck).

For more details, see [Understand How Flyte Handles Data](https://www.union.ai/docs/v1/union/architecture/data-handling).

## Changing the raw data storage location

There are a number of ways to change the raw data location:

* When registering your workflow:
  * With `uctl register`, use the flag `--files.outputLocationPrefix`.
  * With `union register`, use the flag `--raw-data-prefix`.
* At the execution level:
  * In the UI, set the **Raw output data config** parameter in the execution dialog.

These options change the raw data location for **all large types** (`FlyteFile`, `FlyteDirectory`, `DataFrame`, any other large data object).

If you are only concerned with controlling where raw data used by `FlyteFile` or `FlyteDirectory` is stored, you can [set the `remote_path` parameter](./flyte-file-and-flyte-directory#specifying-remote_path-for-a-flytefile-or-flytedirectory) in your task code when initializing objects of those types.

### Setting up your own object store

By default, when Union.ai marshals values across tasks, it stores both metadata and raw data in its own dedicated object store bucket.
While this bucket is located in your Union.ai BYOC data plane and is therefore under your control, it is part of the Union.ai implementation and should not be accessed or modified directly by your task code.

When changing the default raw data location, the target should therefore be a bucket that you set up, separate from the Union.ai-implemented bucket.

For information on setting up your own bucket and enabling access to it, see [Enabling AWS S3](../integrations/enabling-aws-resources/enabling-aws-s3), [Enabling Google Cloud Storage](../integrations/enabling-gcp-resources/enabling-google-cloud-storage), or [Enabling Azure Blob Storage](../integrations/enabling-azure-resources/enabling-azure-blob-storage), depending on your cloud provider.

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/accelerated-datasets ===

# Accelerated datasets

> [!NOTE] *Accelerated datasets* and *Accelerators* are entirely different things
> Accelerated datasets is a Union.ai feature that enables quick access to large datasets from within a task.
> An [accelerator](https://www.union.ai/docs/v1/union/user-guide/core-concepts/tasks/task-hardware-environment/accelerators), on the other hand, is a specialized hardware device that is used to accelerate the execution of a task.
> These concepts are entirely different and should not be confused.

Many of the workflows that you may want to run in Union.ai will involve tasks that use large static assets such as reference genomes, training datasets, or pre-trained models.
These assets are often stored in an object store and need to be downloaded to the task pod each time before the task can run.
This can be a significant bottleneck, especially if the data must be loaded into memory to be randomly accessed and therefore cannot be streamed.

To remedy this, Union.ai provides a way to preload large static assets into a shared object store that is mounted to all machine nodes in your cluster by default.
This allows you to upload your data once and then access it from any task without needing to download it each time.

Data items stored in this way are called *accelerated datasets*.

> [!NOTE] Only on S3
> Currently, this feature is only available for AWS S3.

## How it works

* Each customer has a dedicated S3 bucket where they can store their accelerated datasets.
* The naming and setup of this bucket must be coordinated with the Union.ai team so that a suitable name is chosen. Typically, it will be something like `s3://union-<org-name>-persistent`.
* You can upload any data you wish to this bucket.
* The bucket will be automatically mounted into every node in your cluster.
* To your task logic, it will appear to be a local directory in the task container.
* To use it, initialize a `FlyteFile` object with the path to the data file and pass it into a task as an input.
    * Note that in order for the system to recognize the file as an accelerated dataset, it must be created as a `FlyteFile` and that `FlyteFile` must be passed *into* a task.
      If you try to access the file directly from the object store, it will not be recognized as an accelerated dataset and the data will not be found.

## Example usage

Assuming that your organization is called `my-company` and the file you want to access is called `my_data.csv`, you would first need to upload the file to the persistent bucket. See [Upload a File to Your Amazon S3 Bucket](https://docs.aws.amazon.com/quickstarts/latest/s3backup/step-2-upload-file.html).

The code to access the data looks like this:

```python
import union

@union.task
def my_task(f: union.FlyteFile) -> int:
    with open(f, newline="\n") as input_file:
        data = input_file.read()
    # Do something with the data; here we simply count its lines
    return len(data.splitlines())

@union.workflow
def my_wf():
    my_task(f=union.FlyteFile("s3://union-my-company-persistent/my_data.csv"))
```

Note that you do not have to invoke `FlyteFile.download()` because the file will already have been made available locally within the container.

## Considerations

### Caching

While the persistent bucket appears to your task as a locally mounted volume, the data itself will not be resident in the local file system until after the first access. After the first access it will be cached locally. This fact should be taken into account when using this feature.

### Storage consumption

Data cached during the use of accelerated datasets will consume local storage on the nodes in your cluster. This should be taken into account when selecting and sizing your cluster nodes.

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/accessing-attributes ===

# Accessing attributes

You can directly access attributes on output promises for lists, dictionaries, dataclasses, and combinations of these types in Union.ai.
Note that while this functionality may appear to be the normal behavior of Python, code in `@workflow` functions is not actually Python, but rather a Python-like DSL that is compiled by Union.ai.
Consequently, accessing attributes in this manner is, in fact, a specially implemented feature.
This functionality facilitates the direct passing of output attributes within workflows, enhancing the convenience of working with complex data structures.
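
To see why this needs special support, consider a toy model in plain Python. `ToyPromise` is hypothetical and is not flytekit's `Promise`: it records each index, key, or attribute access as a path while the workflow is being defined, and only replays that path against the actual task output later. Because resolution is deferred, a bad index or key fails only at run time, while slices are rejected up front, mirroring the list-slicing limitation described in this section.

```python
from typing import Any, Tuple, Union

Key = Union[int, str]


class ToyPromise:
    """Toy model of an output promise; not flytekit's Promise class.

    Index, key, and attribute accesses are recorded as a path when the
    workflow is defined, then replayed against the real output at run time.
    """

    def __init__(self, path: Tuple[Tuple[str, Key], ...] = ()):
        self._path = path

    def __getitem__(self, key: Any) -> "ToyPromise":
        if isinstance(key, slice):
            # Mirrors the documented limitation: no list slicing on promises
            raise TypeError("slicing an output promise is not supported")
        return ToyPromise(self._path + (("item", key),))

    def __getattr__(self, name: str) -> "ToyPromise":
        # Only called for attributes not found normally (e.g. dataclass fields)
        if name.startswith("_"):
            raise AttributeError(name)
        return ToyPromise(self._path + (("attr", name),))

    def resolve(self, value: Any) -> Any:
        # Replay the recorded accesses; a bad index or key fails only here
        for kind, key in self._path:
            value = value[key] if kind == "item" else getattr(value, key)
        return value


if __name__ == "__main__":
    outputs = ({"fruits": ["banana"]}, [{"fruit": "banana"}])
    first_fruit = ToyPromise()[0]["fruits"][0]  # nothing is evaluated yet
    print(first_fruit.resolve(outputs))  # banana
```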

To begin, import the required dependencies and define a common task for subsequent use:

```python
from dataclasses import dataclass
import union

@union.task
def print_message(message: str):
    print(message)
    return
```

## List
You can access an output list using index notation.

> [!NOTE]
> Union.ai currently does not support output promise access through list slicing.

```python
@union.task
def list_task() -> list[str]:
    return ["apple", "banana"]

@union.workflow
def list_wf():
    items = list_task()
    first_item = items[0]
    print_message(message=first_item)
```

## Dictionary
Access the output dictionary by specifying the key.

```python
@union.task
def dict_task() -> dict[str, str]:
    return {"fruit": "banana"}

@union.workflow
def dict_wf():
    fruit_dict = dict_task()
    print_message(message=fruit_dict["fruit"])
```

## Data class
Directly access an attribute of a dataclass.

```python
@dataclass
class Fruit:
    name: str

@union.task
def dataclass_task() -> Fruit:
    return Fruit(name="banana")

@union.workflow
def dataclass_wf():
    fruit_instance = dataclass_task()
    print_message(message=fruit_instance.name)
```

## Complex type
Combinations of list, dict and dataclass also work effectively.

```python
@union.task
def advance_task() -> (dict[str, list[str]], list[dict[str, str]], dict[str, Fruit]):
    return {"fruits": ["banana"]}, [{"fruit": "banana"}], {"fruit": Fruit(name="banana")}

@union.task
def print_list(fruits: list[str]):
    print(fruits)

@union.task
def print_dict(fruit_dict: dict[str, str]):
    print(fruit_dict)

@union.workflow
def advanced_workflow():
    dictionary_list, list_dict, dict_dataclass = advance_task()
    print_message(message=dictionary_list["fruits"][0])
    print_message(message=list_dict[0]["fruit"])
    print_message(message=dict_dataclass["fruit"].name)

    print_list(fruits=dictionary_list["fruits"])
    print_dict(fruit_dict=list_dict[0])
```

You can run all the workflows locally as follows:

```python
if __name__ == "__main__":
    list_wf()
    dict_wf()
    dataclass_wf()
    advanced_workflow()
```

## Failure scenario
The following workflow fails because it accesses an index that is out of range, a key that does not exist, and an attribute that does not exist:

```python
from flytekit import WorkflowFailurePolicy

@union.task
def failed_task() -> (list[str], dict[str, str], Fruit):
    return ["apple", "banana"], {"fruit": "banana"}, Fruit(name="banana")

@union.workflow(
    # If a node fails, keep executing the remaining runnable nodes,
    # then mark the workflow as failed once they complete
    failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
)
def failed_workflow():
    fruits_list, fruit_dict, fruit_instance = failed_task()
    print_message(message=fruits_list[100])  # Accessing an index that doesn't exist
    print_message(message=fruit_dict["fruits"])  # Accessing a non-existent key
    print_message(message=fruit_instance.fruit)  # Accessing a non-existent attribute
```
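
To see why each access in `failed_workflow` is invalid, the following sketch performs the same three accesses in plain Python on the values `failed_task` returns. It only illustrates which kind of error each case corresponds to; the error that Union.ai reports for a failed node will look different. The `error_name` helper is hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Fruit:
    name: str


def error_name(access) -> str:
    """Run a zero-argument accessor and return the name of the exception it raises."""
    try:
        access()
    except Exception as exc:  # demo only: we just want the exception type
        return type(exc).__name__
    return "no error"


# The same values that failed_task returns.
fruits_list, fruit_dict, fruit_instance = ["apple", "banana"], {"fruit": "banana"}, Fruit(name="banana")

print(error_name(lambda: fruits_list[100]))      # IndexError
print(error_name(lambda: fruit_dict["fruits"]))  # KeyError
print(error_name(lambda: fruit_instance.fruit))  # AttributeError
```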

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/dataclass ===

<!-- TODO: check for variant accuracy, remove mention of flytesnacks figure out "UnionTypes" -->

# Dataclass

When you have multiple values that you want to pass between Union.ai entities, you can group them in a `dataclass`.

To begin, import the necessary dependencies:

```python
import os
import tempfile
from dataclasses import dataclass

import pandas as pd
import union
from flytekit.types.structured import StructuredDataset
```

Build your custom image with ImageSpec:
```python
image_spec = union.ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=["pandas", "pyarrow"],
)
```

## Python types
We define a `dataclass` with `int`, `str` and `dict` as the data types.

```python
@dataclass
class Datum:
    x: int
    y: str
    z: dict[int, str]
```

You can send a `dataclass` between different tasks written in various languages, and input it through the Union.ai UI as raw JSON.

> [!NOTE]
> All fields in a dataclass should be **annotated with their type**. Failure to do so will result in an error.
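
Python's `dataclasses` module gives a hint about why annotations matter: only annotated class variables become fields. In this minimal sketch, `note` is a hypothetical extra attribute without an annotation, so it is silently left out of the dataclass's fields and therefore out of any interface derived from them.

```python
from dataclasses import dataclass, fields


@dataclass
class Datum:
    x: int
    y: str
    z: dict[int, str]
    note = "unannotated"  # plain class attribute: NOT a dataclass field


print([f.name for f in fields(Datum)])  # ['x', 'y', 'z']
```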

Once declared, a dataclass can be returned as an output or accepted as an input.

```python
@union.task(container_image=image_spec)
def stringify(s: int) -> Datum:
    """
    A dataclass return will be treated as a single complex JSON return.
    """
    return Datum(x=s, y=str(s), z={s: str(s)})

@union.task(container_image=image_spec)
def add(x: Datum, y: Datum) -> Datum:
    x.z.update(y.z)
    return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z)
```

## Union.ai types
We also define a data class that accepts `StructuredDataset`, `FlyteFile` and `FlyteDirectory`.

```python
@dataclass
class UnionTypes:
    dataframe: StructuredDataset
    file: union.FlyteFile
    directory: union.FlyteDirectory

@union.task(container_image=image_spec)
def upload_data() -> UnionTypes:
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    temp_dir = tempfile.mkdtemp(prefix="union-")
    df.to_parquet(os.path.join(temp_dir, "df.parquet"))

    file_path = tempfile.NamedTemporaryFile(delete=False)
    file_path.write(b"Hello, World!")
    file_path.close()  # flush the contents to disk before the file is uploaded

    fs = UnionTypes(
        dataframe=StructuredDataset(dataframe=df),
        file=union.FlyteFile(file_path.name),
        directory=union.FlyteDirectory(temp_dir),
    )
    return fs

@union.task(container_image=image_spec)
def download_data(res: UnionTypes):
    assert pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}).equals(res.dataframe.open(pd.DataFrame).all())
    with open(res.file, "r") as f:
        assert f.read() == "Hello, World!"
    assert os.listdir(res.directory) == ["df.parquet"]
```

A dataclass can contain fields of Python types, other dataclasses,
`FlyteFile`, `FlyteDirectory`, and `StructuredDataset`.

We define a workflow that calls the tasks created above.

```python
@union.workflow
def dataclass_wf(x: int, y: int) -> (Datum, UnionTypes):
    o1 = add(x=stringify(s=x), y=stringify(s=y))
    o2 = upload_data()
    download_data(res=o2)
    return o1, o2
```

To run the `add` task, which accepts dataclass inputs, with `union run`, you can provide a JSON file for each input:

```shell
$ union run dataclass.py add --x dataclass_input.json --y dataclass_input.json
```
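
The contents of `dataclass_input.json` are not shown here. As a minimal sketch, assuming the JSON object's keys mirror the `Datum` field names, you can generate such a file with the standard library. JSON object keys are always strings, so the `int` keys of `z` are written as `"1"`; we assume the dataclass deserializer converts them back to `int`.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class Datum:
    x: int
    y: str
    z: dict[int, str]


def write_input(path: str, datum: Datum) -> str:
    """Serialize a Datum to JSON, write it to `path`, and return the JSON text."""
    text = json.dumps(asdict(datum), indent=2)  # int keys in `z` become JSON strings
    with open(path, "w") as f:
        f.write(text)
    return text


print(write_input("dataclass_input.json", Datum(x=1, y="1", z={1: "1"})))
```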

You can also run the `add` task directly from the example file in the flytesnacks repository, again providing a JSON file for each input:

```shell
$ union run \
  https://raw.githubusercontent.com/flyteorg/flytesnacks/69dbe4840031a85d79d9ded25f80397c6834752d/examples/data_types_and_io/data_types_and_io/dataclass.py \
  add --x dataclass_input.json --y dataclass_input.json
```

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/enum ===

# Enum type

At times, you might need to limit the acceptable values for inputs or outputs to a predefined set.
This common requirement is usually met by using `Enum` types in programming languages.

You can create a Python `Enum` type and utilize it as an input or output for a task.
Union will automatically convert it and constrain the inputs and outputs to the predefined set of values.

> [!NOTE]
> Currently, only string values are supported as valid `Enum` values.
> Union.ai assumes the first value in the list as the default, and `Enum` types cannot be optional.
> Therefore, when defining `Enum`s, it's important to design them with the first value as a valid default.

We define an `Enum` and a simple coffee maker workflow that accepts an order and brews coffee ☕️ accordingly.
The assumption is that the coffee maker only understands `Enum` inputs:

```python
# coffee_maker.py

from enum import Enum

import union

class Coffee(Enum):
    ESPRESSO = "espresso"
    AMERICANO = "americano"
    LATTE = "latte"
    CAPPUCCINO = "cappuccino"

@union.task
def take_order(coffee: str) -> Coffee:
    return Coffee(coffee)

@union.task
def prep_order(coffee_enum: Coffee) -> str:
    return f"Preparing {coffee_enum.value} ..."

@union.workflow
def coffee_maker(coffee: str) -> str:
    coffee_enum = take_order(coffee=coffee)
    return prep_order(coffee_enum=coffee_enum)

# The workflow can also accept an enum value
@union.workflow
def coffee_maker_enum(coffee_enum: Coffee) -> str:
    return prep_order(coffee_enum=coffee_enum)
```

You can specify a value for the `coffee_enum` parameter when you run the workflow:

```shell
$ union run coffee_maker.py coffee_maker_enum --coffee_enum="latte"
```
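
Because `Coffee` is an ordinary Python `Enum`, you can check locally how string values map to members. This sketch repeats the `Coffee` definition so it runs on its own. It shows that members keep their definition order (the first one, `ESPRESSO`, is the value Union.ai treats as the default), that `Coffee("latte")` looks a member up by its value, and that a value outside the predefined set raises `ValueError`.

```python
from enum import Enum


class Coffee(Enum):
    ESPRESSO = "espresso"
    AMERICANO = "americano"
    LATTE = "latte"
    CAPPUCCINO = "cappuccino"


# Members keep their definition order; the first one is what Union.ai uses as the default.
default = next(iter(Coffee))

# Looking a member up by value is how a string such as "latte" maps to Coffee.LATTE.
latte = Coffee("latte")

# A value outside the predefined set is rejected.
try:
    Coffee("mocha")
    rejected = False
except ValueError:
    rejected = True

print(default, latte, rejected)  # Coffee.ESPRESSO Coffee.LATTE True
```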

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/pickle ===

# Pickle type

Union.ai enforces type safety by utilizing type information for compiling tasks and workflows,
enabling various features such as static analysis and conditional branching.

However, we also strive to offer flexibility to end-users, so they don't have to invest heavily
in understanding their data structures upfront before experiencing the value Union.ai has to offer.

Union.ai supports the `FlytePickle` transformer, which converts any unrecognized type hint into `FlytePickle`,
enabling the serialization/deserialization of Python values to/from a pickle file.

> [!NOTE]
> Pickle can only be used to send objects between tasks that run the exact same Python version.
> Because pickle types can result in lower performance, it's advisable to either use Python types that are supported by Union.ai
> or register a custom transformer.

This example demonstrates how you can utilize custom objects without registering a transformer.

```python
import union
```

`Superhero` represents a user-defined complex type that can be serialized to a pickle file by Union
and transferred between tasks as both input and output data.

> [!NOTE]
> Alternatively, you can [turn this object into a dataclass](./dataclass) for improved performance.
> We have used a simple object here for demonstration purposes.

```python
class Superhero:
    def __init__(self, name, power):
        self.name = name
        self.power = power

@union.task
def welcome_superhero(name: str, power: str) -> Superhero:
    return Superhero(name, power)

@union.task
def greet_superhero(superhero: Superhero) -> str:
    return f"👋 Hello {superhero.name}! Your superpower is {superhero.power}."

@union.workflow
def superhero_wf(name: str = "Thor", power: str = "Flight") -> str:
    superhero = welcome_superhero(name=name, power=power)
    return greet_superhero(superhero=superhero)
```
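
Conceptually, `FlytePickle` serializes the value with Python's `pickle` module in the producing task and deserializes it in the consuming task. The following stdlib-only sketch shows that round trip for `Superhero`; it does not reproduce how Union.ai stores or transfers the pickle file. Note that unpickling also requires that the receiving process can import the same `Superhero` class.

```python
import pickle


class Superhero:
    def __init__(self, name, power):
        self.name = name
        self.power = power


hero = Superhero("Thor", "Flight")

# Producing task: turn the object into bytes (what would be written to a pickle file).
payload = pickle.dumps(hero)

# Consuming task: rebuild an equivalent object from those bytes.
restored = pickle.loads(payload)

print(type(payload).__name__, restored.name, restored.power)  # bytes Thor Flight
```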

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/pydantic ===

# Pydantic BaseModel

<!-- TODO: check for variant accuracy figure out UnionTypes-->

> [!NOTE]
> You can put dataclasses and Union.ai types (`FlyteFile`, `FlyteDirectory`, `FlyteSchema`, and `StructuredDataset`) in a Pydantic `BaseModel`.

<!-- TODO: check above for variant accuracy -->

To begin, import the necessary dependencies:

```python
import os
import tempfile
import pandas as pd
import union
from flytekit.types.structured import StructuredDataset
from pydantic import BaseModel
```

Build your custom image with ImageSpec:
```python
image_spec = union.ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=["pandas", "pyarrow", "pydantic"],
)
```

## Python types
We define a Pydantic `BaseModel` with `int`, `str` and `dict` as the data types.

```python
class Datum(BaseModel):
    x: int
    y: str
    z: dict[int, str]
```

You can send a Pydantic model between different tasks written in various
languages, and input it through the Union.ai console as raw
JSON.

> [!NOTE]
> All fields in a Pydantic model should be **annotated with their type**. Failure
> to do so will result in an error.

Once declared, a Pydantic model can be returned as an output or accepted as an input.

```python
@union.task(container_image=image_spec)
def stringify(s: int) -> Datum:
    """
    A Pydantic model return will be treated as a single complex JSON return.
    """
    return Datum(x=s, y=str(s), z={s: str(s)})

@union.task(container_image=image_spec)
def add(x: Datum, y: Datum) -> Datum:
    x.z.update(y.z)
    return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z)
```

## Union.ai types

We also define a Pydantic model that contains `StructuredDataset`, `FlyteFile` and
`FlyteDirectory` fields.

```python
class UnionTypes(BaseModel):
    dataframe: StructuredDataset
    file: union.FlyteFile
    directory: union.FlyteDirectory

@union.task(container_image=image_spec)
def upload_data() -> UnionTypes:
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    temp_dir = tempfile.mkdtemp(prefix="flyte-")
    df.to_parquet(os.path.join(temp_dir, "df.parquet"))

    file_path = tempfile.NamedTemporaryFile(delete=False)
    file_path.write(b"Hello, World!")
    file_path.close()

    fs = UnionTypes(
        dataframe=StructuredDataset(dataframe=df),
        file=union.FlyteFile(file_path.name),
        directory=union.FlyteDirectory(temp_dir),
    )
    return fs

@union.task(container_image=image_spec)
def download_data(res: UnionTypes):
    expected_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    actual_df = res.dataframe.open(pd.DataFrame).all()
    assert expected_df.equals(actual_df), "DataFrames do not match!"

    with open(res.file, "r") as f:
        assert f.read() == "Hello, World!", "File contents do not match!"

    assert os.listdir(res.directory) == ["df.parquet"], "Directory contents do not match!"
```

A Pydantic model can contain fields of Python types, dataclasses,
`FlyteFile`, `FlyteDirectory`, and `StructuredDataset`.

We define a workflow that calls the tasks created above.

```python
@union.workflow
def basemodel_wf(x: int, y: int) -> tuple[Datum, UnionTypes]:
    o1 = add(x=stringify(s=x), y=stringify(s=y))
    o2 = upload_data()
    download_data(res=o2)
    return o1, o2
```

To run the workflow with `union run`, pass the integer inputs on the command line:

```shell
$ union run pydantic_basemodel.py basemodel_wf --x 1 --y 2
```

You can also run the workflow directly from the example file in the flytesnacks repository:

```shell
$ union run \
  https://raw.githubusercontent.com/flyteorg/flytesnacks/b71e01d45037cea883883f33d8d93f258b9a5023/examples/data_types_and_io/data_types_and_io/pydantic_basemodel.py \
  basemodel_wf --x 1 --y 2
```

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/pytorch ===

# PyTorch type

Union.ai advocates for the use of strongly-typed data to simplify the development of robust and testable pipelines. In addition to its application in data engineering, Union.ai is primarily used for machine learning.
To streamline the communication between Union.ai tasks, particularly when dealing with tensors and models, we have introduced support for PyTorch types.

## Tensors and modules

At times, you may find the need to pass tensors and modules (models) within your workflow. Without native support for PyTorch tensors and modules, Union relies on [pickle](https://docs-builder.pages.dev/docs/byoc/user-guide/data-input-output/pickle/) for serializing and deserializing these entities, as well as any unknown types. However, this approach isn't the most efficient. As a result, we've integrated PyTorch's serialization and deserialization support into the Union.ai type system.

```python
import torch
import union

@union.task
def generate_tensor_2d() -> torch.Tensor:
    return torch.tensor([[1.0, -1.0, 2], [1.0, -1.0, 9], [0, 7.0, 3]])

@union.task
def reshape_tensor(tensor: torch.Tensor) -> torch.Tensor:
    # convert 2D to 3D
    tensor.unsqueeze_(-1)
    return tensor.expand(3, 3, 2)

@union.task
def generate_module() -> torch.nn.Module:
    bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    return bn

@union.task
def get_model_weight(model: torch.nn.Module) -> torch.Tensor:
    return model.weight

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

@union.task
def get_l1() -> torch.nn.Module:
    model = MyModel()
    return model.l1

@union.workflow
def pytorch_native_wf():
    reshape_tensor(tensor=generate_tensor_2d())
    get_model_weight(model=generate_module())
    get_l1()
```

Passing tensors and modules between tasks is no longer a hassle.

## Checkpoint

`PyTorchCheckpoint` is a specialized checkpoint used for serializing and deserializing PyTorch models.
It checkpoints `torch.nn.Module`'s state, hyperparameters and optimizer state.

This module checkpoint differs from the standard checkpoint in that it specifically captures the module's `state_dict`.
Therefore, to restore the module, you must instantiate the module class and then load the saved `state_dict` into it.
According to the PyTorch [docs](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model),
it's recommended to store the module's `state_dict` rather than the module itself,
although the serialization should work in either case.

```python
from dataclasses import dataclass

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses_json import dataclass_json
from flytekit.extras.pytorch import PyTorchCheckpoint
import union

@dataclass_json
@dataclass
class Hyperparameters:
    epochs: int
    loss: float

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

@union.task
def generate_model(hyperparameters: Hyperparameters) -> PyTorchCheckpoint:
    bn = Net()
    optimizer = optim.SGD(bn.parameters(), lr=0.001, momentum=0.9)
    return PyTorchCheckpoint(module=bn, hyperparameters=hyperparameters, optimizer=optimizer)

@union.task
def load(checkpoint: PyTorchCheckpoint):
    new_bn = Net()
    new_bn.load_state_dict(checkpoint["module_state_dict"])
    optimizer = optim.SGD(new_bn.parameters(), lr=0.001, momentum=0.9)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

@union.workflow
def pytorch_checkpoint_wf():
    checkpoint = generate_model(hyperparameters=Hyperparameters(epochs=10, loss=0.1))
    load(checkpoint=checkpoint)
```

> [!NOTE]
> `PyTorchCheckpoint` supports serializing hyperparameters of types `dict`, `NamedTuple` and `dataclass`.

## Auto GPU to CPU and CPU to GPU conversion

Not all PyTorch computations require a GPU. In some cases, it can be advantageous to transfer the
computation to a CPU, especially after training the model on a GPU.
To utilize the power of a GPU, the typical construct to use is: `to(torch.device("cuda"))`.

When working with GPU variables on a CPU, the variables must first be transferred to the CPU using the `to(torch.device("cpu"))` construct.
However, performing this conversion manually, as PyTorch recommends, is not very user-friendly.
To address this, we added support for automatic GPU to CPU conversion (and vice versa) for PyTorch types.

```python
from typing import Tuple

import torch
import union
from flytekit.extras.pytorch import PyTorchCheckpoint

# `Model` and the training/test tensors are defined in the elided code (`...`).

@union.task(requests=union.Resources(gpu="1"))
def train() -> Tuple[PyTorchCheckpoint, torch.Tensor, torch.Tensor, torch.Tensor]:
    ...
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model(X_train.shape[1])
    model.to(device)
    ...
    X_train, X_test = X_train.to(device), X_test.to(device)
    y_train, y_test = y_train.to(device), y_test.to(device)
    ...
    return PyTorchCheckpoint(module=model), X_train, X_test, y_test

@union.task
def predict(
    checkpoint: PyTorchCheckpoint,
    X_train: torch.Tensor,
    X_test: torch.Tensor,
    y_test: torch.Tensor,
):
    new_bn = Model(X_train.shape[1])
    new_bn.load_state_dict(checkpoint["module_state_dict"])

    with torch.no_grad():
        y_pred = new_bn(X_test)
        correct = (torch.argmax(y_pred, dim=1) == y_test).type(torch.FloatTensor)
        accuracy = correct.mean()
    print(f"Accuracy: {accuracy.item():.4f}")
```

Because the `predict` task does not request a GPU, it runs on a CPU, and
the device conversion from GPU to CPU is handled automatically by Union.

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/structured-dataset ===

# StructuredDataset

As with most type systems, Python has primitives, container types like maps and tuples, and support for user-defined structures. However, while there’s a rich variety of DataFrame classes (Pandas, Spark, Pandera, etc.), there’s no native Python type that represents a DataFrame in the abstract. This is the gap that the `StructuredDataset` type is meant to fill. It offers the following benefits:

- Eliminate boilerplate code you would otherwise need to write to serialize/deserialize from file objects into DataFrame instances.
- Eliminate additional inputs/outputs that convey metadata about the format of the tabular data held in those files.
- Add flexibility in how DataFrame files are loaded.
- Offer a range of DataFrame-specific functionality: enforce compatibility of different schemas
  (not only at compile time, but also at runtime, since type information is carried along in the literal),
  store third-party schema definitions, and potentially, in the future, render sample data, provide summary statistics, etc.

## Usage

To use the `StructuredDataset` type, import `pandas` and define a task that returns a Pandas DataFrame.
Union will detect the Pandas DataFrame return signature and convert the interface for the task to
the `StructuredDataset` type.

## Example

This example demonstrates how to work with a structured dataset using Union.ai entities.

> [!NOTE]
> To use the `StructuredDataset` type, you only need to import `pandas`. The other imports specified below are only necessary for this specific example.

To begin, import the dependencies for the example:

```python
import typing
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import union
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.types.structured.structured_dataset import (
    PARQUET,
    StructuredDataset,
    StructuredDatasetDecoder,
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)
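from typing_extensions import Annotated

image_spec = union.ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=["pandas", "pyarrow", "numpy"],
)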
```

Define a task that returns a Pandas DataFrame.

```python
@union.task(container_image=image_spec)
def generate_pandas_df(a: int) -> pd.DataFrame:
    return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [a, 22], "Height": [160, 178]})
```

Using this simplest form, however, the user is not able to set the additional DataFrame information alluded to above:

- Column type information
- Serialized byte format
- Storage driver and location
- Additional third party schema information

This is by design, as we wanted the default case to suffice for the majority of use cases and to require
as few changes to existing code as possible. Specifying these is simple, however, and relies on Python variable annotations,
which are designed explicitly to supplement types with arbitrary metadata.
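
The annotation mechanism referred to here is `typing.Annotated`. This stdlib-only sketch uses a hypothetical `StructuredDatasetStandIn` class in place of `StructuredDataset` and hypothetical metadata values. It shows that the extra arguments of `Annotated` travel with the type hint and can be read back with `typing.get_args`, which is the general mechanism that lets a framework find column and format information in a task signature.

```python
from typing import Annotated, get_args


class StructuredDatasetStandIn:
    """Hypothetical placeholder for StructuredDataset in this sketch."""


# Hypothetical metadata: a column spec (similar in spirit to kwtypes(Age=int)) and a format tag.
AGE_COLUMN = (("Age", int),)
FORMAT = "parquet"


def subset_task() -> Annotated[StructuredDatasetStandIn, AGE_COLUMN, FORMAT]:
    ...


hint = subset_task.__annotations__["return"]
base, *metadata = get_args(hint)
print(base.__name__)  # StructuredDatasetStandIn
print(metadata)       # [(('Age', <class 'int'>),), 'parquet']
```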

## Column type information
If you want to extract a subset of actual columns of the DataFrame and specify their types for type validation,
you can just specify the column names and their types in the structured dataset type annotation.

First, initialize column types you want to extract from the `StructuredDataset`.

```python
all_cols = union.kwtypes(Name=str, Age=int, Height=int)
col = union.kwtypes(Age=int)
```

Define a task that opens a structured dataset with `open(pd.DataFrame)` and reads it with `all()`.
When you call `all()`, the Union.ai engine downloads the Parquet file from blob storage (for example, S3) and deserializes it into a `pandas.DataFrame`.
Keep in mind that you can pass any DataFrame type that is supported by, or registered with, `StructuredDataset` to `open()`.
For instance, you can use `pa.Table` to read the data as a PyArrow table instead.

```python
@union.task(container_image=image_spec)
def get_subset_pandas_df(df: Annotated[StructuredDataset, all_cols]) -> Annotated[StructuredDataset, col]:
    df = df.open(pd.DataFrame).all()
    df = pd.concat([df, pd.DataFrame([[30]], columns=["Age"])])
    return StructuredDataset(dataframe=df)

@union.workflow
def simple_sd_wf(a: int = 19) -> Annotated[StructuredDataset, col]:
    pandas_df = generate_pandas_df(a=a)
    return get_subset_pandas_df(df=pandas_df)
```

The code may result in runtime failures if the columns do not match.
The input `df` has `Name`, `Age` and `Height` columns, whereas the output structured dataset will only have the `Age` column.

## Serialized byte format
You can use a custom serialization format to serialize your DataFrames.
Here's how you can register the Pandas to CSV handler, which is already available,
and enable the CSV serialization by annotating the structured dataset with the CSV format:

```python
from flytekit.types.structured import register_csv_handlers
from flytekit.types.structured.structured_dataset import CSV

register_csv_handlers()

@union.task(container_image=image_spec)
def pandas_to_csv(df: pd.DataFrame) -> Annotated[StructuredDataset, CSV]:
    return StructuredDataset(dataframe=df)

@union.workflow
def pandas_to_csv_wf() -> Annotated[StructuredDataset, CSV]:
    pandas_df = generate_pandas_df(a=19)
    return pandas_to_csv(df=pandas_df)
```

## Storage driver and location

By default, the data will be written to the same place that all other pointer-types (FlyteFile, FlyteDirectory, etc.) are written to.
This is controlled by the output data prefix option in Union.ai which is configurable on multiple levels.

That is to say, in the simple default case, Union will:

- Look up the default format for, say, Pandas DataFrames.
- Look up the default storage location based on the raw output prefix setting.
- Use these two settings to select an encoder and invoke it.

So what's an encoder? To understand that, let's look into how the structured dataset plugin works.

## Inner workings of a structured dataset plugin

Two things need to happen with any DataFrame instance when interacting with Union.ai:

- Serialization/deserialization from/to the Python instance to bytes (in the format specified above).
- Transmission/retrieval of those bits to/from somewhere.

Each structured dataset plugin (called encoder or decoder) needs to perform both of these steps.
Union decides which of the loaded plugins to invoke based on three attributes:

- The byte format
- The storage location
- The Python type in the task or workflow signature.

These three keys uniquely identify which encoder (used when converting a DataFrame in Python memory to a Union.ai value,
e.g. when a task finishes and returns a DataFrame) or decoder (used when hydrating a DataFrame in memory from a Union.ai value,
e.g. when a task starts and has a DataFrame input) to invoke.
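
The two steps every encoder and decoder performs (serialization and transmission) can be illustrated without Union.ai. This is a toy sketch with hypothetical `encode` and `decode` functions: CSV stands in for the byte format, and a local temporary directory addressed with a `file://` prefix stands in for blob storage. Note that CSV round-trips every value as a string.

```python
import csv
import io
import shutil
import tempfile
from pathlib import Path

PREFIX = "file://"


def encode(rows: list[dict], remote_dir: Path) -> str:
    """Toy encoder: serialize rows to CSV bytes, then 'upload' the file."""
    # Step 1: Python object -> bytes in a chosen format (CSV here).
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)
    local_file = Path(tempfile.mkdtemp()) / "data.csv"
    local_file.write_bytes(buf.getvalue().encode())
    # Step 2: transmit the bytes to storage (a local directory stands in for a bucket).
    target = remote_dir / "data.csv"
    shutil.copy(local_file, target)
    return PREFIX + str(target)


def decode(uri: str) -> list[dict]:
    """Toy decoder: 'download' the file, then deserialize the CSV bytes."""
    # Step 1: retrieve the bytes from storage into a local file.
    local_file = Path(tempfile.mkdtemp()) / "data.csv"
    shutil.copy(uri[len(PREFIX):], local_file)
    # Step 2: bytes -> Python object.
    with local_file.open(newline="") as f:
        return list(csv.DictReader(f))


remote = Path(tempfile.mkdtemp())  # stands in for blob storage
uri = encode([{"Name": "Tom", "Age": "20"}, {"Name": "Joseph", "Age": "22"}], remote)
print(decode(uri))
```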

However, it is awkward to require users to use `typing.Annotated` on every signature.
Therefore, Union has a default byte-format for every registered Python DataFrame type.

## The `uri` argument

The `uri` argument lets you read a structured dataset from, or write it to, a specific location, such as S3, GCS, or BigQuery.
For a BigQuery table, the `uri` takes the form `bq://<PROJECT_ID>:<DATASET>.<TABLE>`.
If you specify a BigQuery `uri` for a structured dataset, BigQuery creates a table in the location specified by the `uri`.

Before writing a DataFrame to a BigQuery table:

1. Create a [GCP account](https://cloud.google.com/docs/authentication/getting-started) and create a service account.
2. Create a project and set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable (for example, in your `.bashrc` file) to the path of your service account key file.
3. Create a dataset in your project.

Here's how you can define a task that converts a pandas DataFrame to a BigQuery table:

```python
@union.task
def pandas_to_bq() -> StructuredDataset:
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    return StructuredDataset(dataframe=df, uri="bq://<PROJECT_ID>:<DATASET>.<TABLE>")
```

Replace `<PROJECT_ID>` with your GCP project ID, `<DATASET>` with the BigQuery dataset you created, and `<TABLE>` with the name of the table the DataFrame should be written to.
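
BigQuery URIs in this section, such as `bq://project:flyte.table`, name a project, a dataset, and a table. As a sketch, the hypothetical `parse_bq_uri` helper below splits such a URI into those parts, which can serve as a quick sanity check before running a task. It assumes the `bq://<project>:<dataset>.<table>` form and is not part of Union.ai.

```python
from typing import NamedTuple


class BigQueryTable(NamedTuple):
    project: str
    dataset: str
    table: str


def parse_bq_uri(uri: str) -> BigQueryTable:
    """Split a bq://<project>:<dataset>.<table> URI into its parts (hypothetical helper)."""
    scheme, sep, rest = uri.partition("://")
    if scheme != "bq" or not sep:
        raise ValueError(f"not a BigQuery URI: {uri!r}")
    project, _, dataset_table = rest.partition(":")
    dataset, _, table = dataset_table.partition(".")
    if not (project and dataset and table):
        raise ValueError(f"expected bq://<project>:<dataset>.<table>, got {uri!r}")
    return BigQueryTable(project, dataset, table)


print(parse_bq_uri("bq://project:flyte.table"))
# BigQueryTable(project='project', dataset='flyte', table='table')
```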

### How the BigQuery encoder is selected

Note that no format was specified in the structured dataset constructor or in the signature. So how did the BigQuery encoder get invoked?
This is because the stock BigQuery encoder is loaded into Union with an empty format.
The Union `StructuredDatasetTransformerEngine` interprets that to mean that it is a generic encoder
(or decoder) that can work across formats if a more specific format is not found.
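
The handler lookup described above can be modeled as a dictionary keyed by Python type, storage protocol, and byte format, with an empty format acting as a wildcard. This is a toy model for intuition only, not the code of `StructuredDatasetTransformerEngine`; the handler names and the `FakeDataFrame` class are hypothetical.

```python
from typing import Dict, Optional, Tuple

# Hypothetical registry: (python_type, protocol, byte_format) -> handler name.
# An empty byte_format ("") marks a generic handler that accepts any format.
HandlerKey = Tuple[type, str, str]
ENCODERS: Dict[HandlerKey, str] = {}


def register(python_type: type, protocol: str, byte_format: str, handler: str) -> None:
    ENCODERS[(python_type, protocol, byte_format)] = handler


def find_encoder(python_type: type, uri: str, byte_format: str) -> Optional[str]:
    """Return the handler for this (type, protocol, format), preferring an exact
    format match and falling back to a generic (empty-format) handler."""
    protocol = uri.split("://", 1)[0] if "://" in uri else "file"
    exact = ENCODERS.get((python_type, protocol, byte_format))
    if exact is not None:
        return exact
    return ENCODERS.get((python_type, protocol, ""))


class FakeDataFrame:
    """Stands in for a real DataFrame class such as pandas.DataFrame."""


register(FakeDataFrame, "s3", "parquet", "pandas-to-parquet-on-s3")
register(FakeDataFrame, "bq", "", "pandas-to-bigquery")  # generic: any format

print(find_encoder(FakeDataFrame, "s3://bucket/data", "parquet"))          # pandas-to-parquet-on-s3
print(find_encoder(FakeDataFrame, "bq://project:flyte.table", "parquet"))  # pandas-to-bigquery
print(find_encoder(FakeDataFrame, "s3://bucket/data", "csv"))              # None
```

In this model, the last lookup returns `None` because neither a CSV handler nor a generic handler is registered for the `s3` protocol.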

And here's how you can define a task that converts the BigQuery table to a pandas DataFrame:

```python
@union.task
def bq_to_pandas(sd: StructuredDataset) -> pd.DataFrame:
    return sd.open(pd.DataFrame).all()
```

> [!NOTE]
> Union.ai creates a table inside the dataset in the project upon BigQuery query execution.

## How to return multiple DataFrames from a task?
For instance, suppose a task needs to return two DataFrames:
- The first should be written to BigQuery and serialized by one of BigQuery's libraries.
- The second should be serialized to CSV and written to a specific location in GCS, different from the generic pointer-data bucket.

If you want the default behavior (which is itself configurable based on which plugins are loaded),
you can simply return your raw DataFrame classes.
If you want to customize the Union.ai interaction behavior, as in this case, wrap each DataFrame in a `StructuredDataset` object:

```python
@union.task
def t1() -> typing.Tuple[StructuredDataset, StructuredDataset]:
    ...  # df1 and df2 are created here
    return (
        StructuredDataset(df1, uri="bq://project:flyte.table"),
        StructuredDataset(df2, uri="gs://auxiliary-bucket/data"),
    )
```

## How to define a custom structured dataset plugin?

`StructuredDataset` ships with an encoder and a decoder that handle the conversion of a
Python value to a Union.ai literal and vice versa, respectively.
Here is a quick demo showcasing how one might build a NumPy encoder and decoder,
enabling the use of a 2D NumPy array as a valid type within structured datasets.

### NumPy encoder

Extend `StructuredDatasetEncoder` and implement the `encode` function.
The `encode` function converts a NumPy array to an intermediate format (Parquet in this case).

```python
class NumpyEncodingHandler(StructuredDatasetEncoder):
    def encode(
        self,
        ctx: union.FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: union.StructuredDatasetType,
    ) -> literals.StructuredDataset:
        df = typing.cast(np.ndarray, structured_dataset.dataframe)
        # Parquet is column-oriented, so build one Arrow array per column of the 2D array.
        name = ["col" + str(i) for i in range(df.shape[1])]
        table = pa.Table.from_arrays(list(df.T), name)
        path = ctx.file_access.get_random_remote_directory()
        local_dir = ctx.file_access.get_random_local_directory()
        local_path = Path(local_dir) / f"{0:05}"
        pq.write_table(table, str(local_path))
        ctx.file_access.upload_directory(local_dir, path)
        return literals.StructuredDataset(
            uri=path,
            metadata=StructuredDatasetMetadata(structured_dataset_type=union.StructuredDatasetType(format=PARQUET)),
        )
```

<!-- TODO: clean up code -->
### NumPy decoder

Extend `StructuredDatasetDecoder` and implement the `StructuredDatasetDecoder.decode` function.
The `StructuredDatasetDecoder.decode` function converts the parquet file to a `numpy.ndarray`.

```python
class NumpyDecodingHandler(StructuredDatasetDecoder):
    def decode(
        self,
        ctx: union.FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> np.ndarray:
        local_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
        table = pq.read_table(local_dir)
        return table.to_pandas().to_numpy()
```

### NumPy renderer

Create a default renderer for NumPy arrays. Union uses this renderer to
display the schema of a NumPy array in the Deck.

```python
class NumpyRenderer:
    def to_html(self, df: np.ndarray) -> str:
        assert isinstance(df, np.ndarray)
        name = ["col" + str(i) for i in range(df.shape[1])]
        table = pa.Table.from_arrays(list(df.T), name)
        return pd.DataFrame(table.schema).to_html(index=False)
```

Finally, register the encoder, decoder, and renderer with the `StructuredDatasetTransformerEngine`.
Each handler takes the Python type it handles (`np.ndarray`),
the storage protocol it works with (`None` here, which means it works with all storage backends),
and the byte format, which in this case is `PARQUET`.

```python
StructuredDatasetTransformerEngine.register(NumpyEncodingHandler(np.ndarray, None, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyDecodingHandler(np.ndarray, None, PARQUET))
StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer())
```

With these handlers registered, you can use `numpy.ndarray` to deserialize a Parquet file into a NumPy array and to serialize a task's NumPy array output to a Parquet file.

```python
@union.task(container_image=image_spec)
def generate_pd_df_with_str() -> pd.DataFrame:
    return pd.DataFrame({"Name": ["Tom", "Joseph"]})

@union.task(container_image=image_spec)
def to_numpy(sd: StructuredDataset) -> Annotated[StructuredDataset, None, PARQUET]:
    numpy_array = sd.open(np.ndarray).all()
    return StructuredDataset(dataframe=numpy_array)

@union.workflow
def numpy_wf() -> Annotated[StructuredDataset, None, PARQUET]:
    return to_numpy(sd=generate_pd_df_with_str())
```

> **📝 Note**
>
> `pyarrow` raises an `Expected bytes, got a 'int' object` error when the DataFrame contains integers.

You can run the code locally as follows:

```python
if __name__ == "__main__":
    sd = simple_sd_wf()
    print(f"A simple Pandas DataFrame workflow: {sd.open(pd.DataFrame).all()}")
    print(f"Using CSV as the serializer: {pandas_to_csv_wf().open(pd.DataFrame).all()}")
    print(f"NumPy encoder and decoder: {numpy_wf().open(np.ndarray).all()}")
```

### Nested typed columns

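Like most storage formats (for example, Avro, Parquet, and BigQuery), `StructuredDataset` supports nested field structures.
To declare the columns a task expects, including nested fields, annotate `StructuredDataset` with `union.kwtypes`, a dictionary, or a dataclass, as in the following example.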

```python
data = [
    {
        "company": "XYZ pvt ltd",
        "location": "London",
        "info": {"president": "Rakesh Kapoor", "contacts": {"email": "contact@xyz.com", "tel": "9876543210"}},
    },
    {
        "company": "ABC pvt ltd",
        "location": "USA",
        "info": {"president": "Kapoor Rakesh", "contacts": {"email": "contact@abc.com", "tel": "0123456789"}},
    },
]

@dataclass
class ContactsField:
    email: str
    tel: str

@dataclass
class InfoField:
    president: str
    contacts: ContactsField

@dataclass
class CompanyField:
    location: str
    info: InfoField
    company: str

MyArgDataset = Annotated[StructuredDataset, union.kwtypes(company=str)]
MyTopDataClassDataset = Annotated[StructuredDataset, CompanyField]
MyTopDictDataset = Annotated[StructuredDataset, {"company": str, "location": str}]

MyDictDataset = Annotated[StructuredDataset, union.kwtypes(info={"contacts": {"tel": str}})]
MyDictListDataset = Annotated[StructuredDataset, union.kwtypes(info={"contacts": {"tel": str, "email": str}})]
MySecondDataClassDataset = Annotated[StructuredDataset, union.kwtypes(info=InfoField)]
MyNestedDataClassDataset = Annotated[StructuredDataset, union.kwtypes(info=union.kwtypes(contacts=ContactsField))]

image = union.ImageSpec(packages=["pandas", "pyarrow", "tabulate"], registry="ghcr.io/flyteorg")

@union.task(container_image=image)
def create_parquet_file() -> StructuredDataset:
    from tabulate import tabulate

    df = pd.json_normalize(data, max_level=0)
    print("original DataFrame: \n", tabulate(df, headers="keys", tablefmt="psql"))

    return StructuredDataset(dataframe=df)

@union.task(container_image=image)
def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame:
    from tabulate import tabulate

    t = sd.open(pd.DataFrame).all()
    print("MyArgDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))
    return t

@union.task(container_image=image)
def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame:
    from tabulate import tabulate

    t = sd.open(pd.DataFrame).all()
    print("MyDictDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))
    return t

@union.task(container_image=image)
def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame:
    from tabulate import tabulate

    t = sd.open(pd.DataFrame).all()
    print("MyDictListDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))
    return t

@union.task(container_image=image)
def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame:
    from tabulate import tabulate

    t = sd.open(pd.DataFrame).all()
    print("MyTopDataClassDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))
    return t

@union.task(container_image=image)
def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame:
    from tabulate import tabulate

    t = sd.open(pd.DataFrame).all()
    print("MyTopDictDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))
    return t

@union.task(container_image=image)
def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame:
    from tabulate import tabulate

    t = sd.open(pd.DataFrame).all()
    print("MySecondDataClassDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))
    return t

@union.task(container_image=image)
def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame:
    from tabulate import tabulate

    t = sd.open(pd.DataFrame).all()
    print("MyNestedDataClassDataset DataFrame: \n", tabulate(t, headers="keys", tablefmt="psql"))
    return t

@union.workflow
def contacts_wf():
    sd = create_parquet_file()
    print_table_by_arg(sd=sd)
    print_table_by_dict(sd=sd)
    print_table_by_list_dict(sd=sd)
    print_table_by_top_dataclass(sd=sd)
    print_table_by_top_dict(sd=sd)
    print_table_by_second_dataclass(sd=sd)
    print_table_by_nested_dataclass(sd=sd)
```

=== PAGE: https://www.union.ai/docs/v1/union/user-guide/data-input-output/tensorflow ===

# TensorFlow types

This page describes the TensorFlow types that Union.ai supports for passing TensorFlow models and TFRecord data between tasks in a workflow.

## Import the necessary libraries and modules

```python
import union
from flytekit.types.directory import TFRecordsDirectory
from flytekit.types.file import TFRecordFile

custom_image = union.ImageSpec(
    packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
    registry="ghcr.io/flyteorg",
)

import tensorflow as tf
```

## TensorFlow model

Union.ai supports the TensorFlow SavedModel format for serializing and deserializing `tf.keras.Model` instances. The `TensorFlowModelTransformer` is responsible for handling these transformations.

### Transformer
- **Name:** TensorFlow Model
- **Class:** `TensorFlowModelTransformer`
- **Python Type:** `tf.keras.Model`
- **Blob Format:** `TensorFlowModel`
- **Dimensionality:** `MULTIPART`

### Usage
The `TensorFlowModelTransformer` allows you to save a TensorFlow model to a remote location and retrieve it later in your Union.ai workflows.

```python
@union.task(container_image=custom_image)
def train_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")]
    )
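    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    # Fitting is omitted for brevity; call model.fit(...) on your training data here.
    return model

@union.task(container_image=custom_image)
def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float:
    # evaluate() returns [loss, accuracy] because the model was compiled with metrics=["accuracy"].
    _, accuracy = model.evaluate(x, y)
    return float(accuracy)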

@union.workflow
def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float:
    model = train_model()
    return evaluate_model(model=model, x=x, y=y)
```

## TFRecord files
Union.ai supports TFRecord files through the `TFRecordFile` type, which can handle serialized TensorFlow records. The `TensorFlowRecordFileTransformer` manages the conversion of TFRecord files to and from Union.ai literals.

### Transformer
- **Name:** TensorFlow Record File
- **Class:** `TensorFlowRecordFileTransformer`
- **Python Type:** `TFRecordFile`
- **Blob Format:** `TensorFlowRecord`
- **Dimensionality:** `SINGLE`

### Usage
The `TensorFlowRecordFileTransformer` enables you to work with single TFRecord files, making it easy to read and write data in TensorFlow's TFRecord format.

```python
@union.task(container_image=custom_image)
def process_tfrecord(file: TFRecordFile) -> int:
    count = 0
    for record in tf.data.TFRecordDataset(file):
        count += 1
    return count

@union.workflow
def tfrecord_workflow(file: TFRecordFile) -> int:
    return process_tfrecord(file=file)
```
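To make concrete what `process_tfrecord` iterates over: a TFRecord file is a sequence of records, each stored as a little-endian `uint64` length, a masked CRC32C of those length bytes, the payload, and a masked CRC32C of the payload. The following standard-library sketch writes and counts such records without TensorFlow. It is illustrative only: the helper names are hypothetical, and it handles uncompressed files only (not `ZLIB` or `GZIP`).

```python
import io
import struct
from typing import BinaryIO, Iterable, List


def _build_crc32c_table() -> List[int]:
    poly = 0x82F63B78  # Castagnoli polynomial, bit-reflected
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ poly if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _build_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = (crc >> 8) ^ _CRC32C_TABLE[(crc ^ byte) & 0xFF]
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add the TFRecord mask constant (32-bit arithmetic)."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(stream: BinaryIO, payloads: Iterable[bytes]) -> None:
    for payload in payloads:
        header = struct.pack("<Q", len(payload))  # uint64 little-endian length
        stream.write(header)
        stream.write(struct.pack("<I", masked_crc(header)))
        stream.write(payload)
        stream.write(struct.pack("<I", masked_crc(payload)))


def count_records(stream: BinaryIO) -> int:
    count = 0
    while True:
        header = stream.read(8)
        if not header:
            return count  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        (header_crc,) = struct.unpack("<I", stream.read(4))
        if header_crc != masked_crc(header):
            raise ValueError("corrupt length checksum")
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if len(payload) != length or payload_crc != masked_crc(payload):
            raise ValueError("corrupt or truncated payload")
        count += 1


buf = io.BytesIO()
write_records(buf, [b"first", b"second", b"third"])
buf.seek(0)
print(count_records(buf))  # 3
```

In a task, keep using `tf.data.TFRecordDataset`, which also supports compressed files; the sketch only shows the on-disk layout.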

## TFRecord directories
Union.ai supports directories containing multiple TFRecord files through the `TFRecordsDirectory` type. The `TensorFlowRecordsDirTransformer` manages the conversion of TFRecord directories to and from Union.ai literals.

### Transformer
- **Name:** TensorFlow Record Directory
- **Class:** `TensorFlowRecordsDirTransformer`
- **Python Type:** `TFRecordsDirectory`
- **Blob Format:** `TensorFlowRecord`
- **Dimensionality:** `MULTIPART`

### Usage
The `TensorFlowRecordsDirTransformer` allows you to work with directories of TFRecord files, which is useful for handling large datasets that are split across multiple files.

#### Example
```python
@union.task(container_image=custom_image)
def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:
    count = 0
    for record in tf.data.TFRecordDataset(dir.path):
        count += 1
    return count

@union.workflow
def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:
    return process_tfrecords_dir(dir=dir)
```
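`tf.data.TFRecordDataset` takes one or more file names rather than a directory path, so code that holds a local directory of TFRecord shards typically expands it into a list of files first. A minimal standard-library sketch, assuming the shards share a `.tfrecord` suffix (a hypothetical naming convention; adjust the pattern to your data) and using a hypothetical helper name:

```python
import tempfile
from pathlib import Path
from typing import List


def list_tfrecord_shards(dir_path: str, pattern: str = "*.tfrecord") -> List[str]:
    """Hypothetical helper: return the shard files in a directory in sorted order."""
    shards = sorted(str(p) for p in Path(dir_path).glob(pattern) if p.is_file())
    if not shards:
        raise FileNotFoundError(f"no files matching {pattern!r} in {dir_path}")
    return shards


with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files; only the names matter for this example.
    for name in ["part-00001.tfrecord", "part-00000.tfrecord", "README.txt"]:
        (Path(tmp) / name).write_bytes(b"")
    shards = list_tfrecord_shards(tmp)
    print([Path(s).name for s in shards])  # ['part-00000.tfrecord', 'part-00001.tfrecord']
```

The resulting list can be passed as the `filenames` argument of `tf.data.TFRecordDataset`. Sorting keeps the record order deterministic, because directory listings are not guaranteed to be ordered.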

## Configuration class: `TFRecordDatasetConfig`
The `TFRecordDatasetConfig` class is a data structure that holds the parameters for creating a `tf.data.TFRecordDataset`, which reads TFRecord files efficiently. The class mixes in `DataClassJsonMixin`, so it can be serialized to and from JSON.

### Attributes
- **compression_type**: (Optional) Specifies the compression method used for the TFRecord files. Possible values include an empty string (no compression), "ZLIB", or "GZIP".
- **buffer_size**: (Optional) Defines the size of the read buffer in bytes. If not set, defaults will be used based on the local or remote file system.
- **num_parallel_reads**: (Optional) Determines the number of files to read in parallel. A value greater than one outputs records in an interleaved order.
- **name**: (Optional) Assigns a name to the operation for easier identification in the pipeline.

`compression_type` must match how the files were written. The other attributes mainly affect read performance, which matters most for large datasets or I/O-bound pipelines.
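The four attributes correspond to the keyword arguments of the same names on `tf.data.TFRecordDataset`. The sketch below uses a hypothetical stand-in dataclass, `ReadOptions`, rather than `TFRecordDatasetConfig` itself, to show one way to validate the compression type and drop unset fields before building those keyword arguments. It uses only the standard library, so the TensorFlow call appears only in a comment.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

# Compression values listed for `compression_type` above.
_VALID_COMPRESSION = {"", "ZLIB", "GZIP"}


@dataclass
class ReadOptions:
    """Hypothetical stand-in mirroring the TFRecordDatasetConfig attributes."""

    compression_type: Optional[str] = None
    buffer_size: Optional[int] = None
    num_parallel_reads: Optional[int] = None
    name: Optional[str] = None

    def __post_init__(self) -> None:
        if self.compression_type is not None and self.compression_type not in _VALID_COMPRESSION:
            raise ValueError(f"unsupported compression_type: {self.compression_type!r}")

    def to_kwargs(self) -> Dict[str, Any]:
        """Keep only the fields that were set, so TensorFlow applies its own defaults for the rest."""
        return {key: value for key, value in asdict(self).items() if value is not None}


options = ReadOptions(compression_type="GZIP", num_parallel_reads=4)
print(options.to_kwargs())  # {'compression_type': 'GZIP', 'num_parallel_reads': 4}
# With TensorFlow installed: tf.data.TFRecordDataset(filenames, **options.to_kwargs())
```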

