David Espejo
Len Strnad

Building Large-Scale Xarray Datasets for Geospatial Computing with Union.ai and Flyte

As geospatial machine learning engineers, we often face a common challenge: how do we effectively build, manage, and process large-scale geospatial datasets that can reach terabytes or even petabytes in size? Traditional approaches using notebooks or simple scripts quickly break down at scale, and manually babysitting these processes becomes unsustainable.

This post explores how to build scalable mosaic workflows using Union.ai and the Flyte orchestration platform combined with GDAL, Xarray, and Dask. This approach has been tested at multiple organizations, enabling us to build scalable, end-to-end remote sensing pipelines. Teams across data-intensive industries like geospatial, autonomous vehicles, and biotech rely on pipelines like these for mission-critical projects. Let’s learn how.

The Foundation: Standards for Remote Sensing and Machine Learning

Before diving into the technical implementation, let's establish some standards that make large-scale geospatial computing more manageable:

1. Data Models with Xarray

Xarray provides an excellent foundation for modeling mosaic datasets. Think of it as representing one large image with multiple dimensions:

  • Dimensions: time, bands, y, x
  • Coordinates: timestamps, band names, latitude, longitude
  • Data: The actual pixel values
# Example of an Xarray dataset structure
# dimensions: (time, band, y, x)
# coordinates:
#   * time: array of datetime64 values
#   * band: array of strings like 'red', 'green', 'blue', 'nir'
#   * y: array of y coordinates (can be latitude)
#   * x: array of x coordinates (can be longitude)

This data model allows us to:

  • Represent arbitrarily large arrays
  • Chunk data in optimal ways for processing
  • Lazy-load specific regions of interest
  • Apply operations to the entire dataset
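
As a concrete illustration, here is a minimal sketch of such a dataset built from dummy data; the extents, chunk sizes, and in-memory construction are hypothetical, but the dimensions and coordinates mirror the structure described above:

import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical extents: 2 timestamps, 4 bands, 10,000 x 10,000 pixels.
times = pd.date_range("2024-01-01", periods=2, freq="MS")
bands = ["red", "green", "blue", "nir"]
y = np.linspace(45.0, 44.0, 10_000)      # e.g. latitude
x = np.linspace(-122.0, -121.0, 10_000)  # e.g. longitude

ds = xr.Dataset(
    {
        "variables": (
            ("time", "band", "y", "x"),
            # Lazy, chunked array: nothing is materialized until computed.
            da.zeros((2, 4, 10_000, 10_000), chunks=(1, 1, 2048, 2048), dtype="float32"),
        )
    },
    coords={"time": times, "band": bands, "y": y, "x": x},
)

# Lazily select a region of interest; only the touched chunks are ever loaded.
subset = ds.sel(band="nir").isel(y=slice(0, 2048), x=slice(0, 2048))

Because the data variable is a chunked Dask array, selections like the one above stay lazy until the result is explicitly computed or written out.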

2. Compute with Dask

Dask works exceptionally well with Xarray for distributed computing, particularly for:

  • Operations on reasonably sized and properly chunked datasets
  • Short-lived tasks like interpolation
  • Distributing work across clusters

However, Dask has limitations:

  • Very small chunks can create too many tasks
  • It's not a workflow orchestration tool
  • You still need something to manage the overall process
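
For instance, a typical pattern (using a hypothetical Zarr store path) opens the mosaic lazily and lets Dask schedule the chunk-wise work; nothing below touches pixel data until the final write:

import xarray as xr
from dask.distributed import Client

# Connect to a local (or remote) Dask cluster; assumed to be available.
client = Client()

# Hypothetical Zarr store produced by the mosaic workflow described below.
ds = xr.open_zarr("s3://my-bucket/mosaic.zarr")

# Chunk-local operations like band math parallelize well across the cluster.
red = ds["variables"].sel(band="red")
nir = ds["variables"].sel(band="nir")
ndvi = (nir - red) / (nir + red)

# Reductions over time are also lazy; computation happens at write time.
ndvi_mean = ndvi.mean(dim="time")
ndvi_mean.to_dataset(name="ndvi_mean").to_zarr("s3://my-bucket/ndvi.zarr", mode="w")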

3. Mosaics with GDAL

For building mosaics (stitching together multiple scenes into one large image), GDAL offers powerful tools:

  • VRT (Virtual Raster): Maps and combines multiple files that share a coordinate reference system
  • Warped VRT: Handles reprojection between different coordinate reference systems
  • GTI (GDAL Raster Tile Index): A newer driver that indexes rasters along with their metadata for faster access
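
As a minimal sketch with hypothetical file names, GDAL's Python bindings expose the first two directly; GTI indexes are typically built with the gdaltindex utility and then opened like any other raster (as the build_gti_xarray helper later in this post does):

from osgeo import gdal

gdal.UseExceptions()

# Hypothetical scenes that already share a CRS: combine into one virtual mosaic.
scenes = ["scene_a.tif", "scene_b.tif", "scene_c.tif"]
vrt = gdal.BuildVRT("mosaic.vrt", scenes)
vrt = None  # close the dataset to flush the VRT to disk

# A warped VRT reprojects lazily at read time instead of rewriting pixels.
warped = gdal.Warp("mosaic_3857.vrt", "mosaic.vrt", dstSRS="EPSG:3857", format="VRT")
warped = None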

Implementing Scalable Workflows with Flyte and Union.ai

Now let's see how to combine these standards with Flyte orchestration to build truly scalable workflows. Flyte is the leading open-source orchestration tool, maintained by Union.ai.

Building a Complete Mosaic Workflow

Let's look at a complete example from the FlyteMosaic repo. The main workflow looks like this:

@workflow
def build_dataset_mosaic_workflow(
    bbox: list[float],
    times: list[datetime.datetime],
    datasets: list[DatasetEnum],
    resolution: float,
    crs: str,
    chunk_partition_size: int,
    xy_chunksize: int = 2048,
) -> str:
    scenes_gdf = ingest_scenes_workflow(
        bbox=bbox,
        times=times,
        datasets=datasets,
    )
    scene_features = build_scene_features_workflow(
        bbox=bbox,
        times=times,
        datasets=datasets,
    )
    scenes_gdf >> scene_features

    gdf_grouped, bounds = build_gti_inputs_task(gdf=scene_features, bounds=bbox)

    gtis = map_task(
        partial(
            gdf_to_gti_task,
            crs=crs,
            resolution=resolution,
        )
    )(gdf=gdf_grouped, bounds=bounds)

    store = build_target_mosaic_task(gtis=gtis, xy_chunksize=xy_chunksize)

    gti_partitions = build_gti_partitions_task(
        store=store,
        chunk_partition_size=chunk_partition_size,
        gtis=gtis,
    )

    map_task(
        partial(
            write_mosaic_partition_task,
            store=store,
            xy_chunksize=xy_chunksize,
        ),
        concurrency=32,
    )(gti_partition=gti_partitions)

    return store

This workflow consists of three main stages:

  1. Ingest Scenes: Find and download all necessary scenes covering our area of interest
  2. Build Scene Features: Process raw scenes into standardized, optimized COGs (Cloud-Optimized GeoTIFFs)
  3. Build Target Mosaic: Create the final mosaic dataset in Zarr format
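
What the second stage produces depends on the source archive, but the core transformation is converting raw scenes into COGs. A hedged, illustrative sketch of that conversion using GDAL's COG driver (the repo's actual scene-feature tasks handle downloading, band selection, and more; the helper name here is made up) might look like this:

from osgeo import gdal

gdal.UseExceptions()

def scene_to_cog(src_path: str, dst_path: str) -> str:
    """Convert a raw scene into a Cloud-Optimized GeoTIFF (illustrative only)."""
    gdal.Translate(
        dst_path,
        src_path,
        format="COG",
        creationOptions=["COMPRESS=DEFLATE", "BLOCKSIZE=512"],
    )
    return dst_path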

The Power and Limits of GTI for Large Mosaics

For the mosaic building step, we can use GDAL's GTI driver to handle large collections of files:

def build_gti_xarray(
    gti: str,
    chunksize: int,
    band_names: list[str],
    resx: float | None = None,
    resy: float | None = None,
    time: datetime.datetime | None = None,
) -> xr.Dataset:

    # see https://gdal.org/en/latest/drivers/raster/gti.html#open-options
    open_kwargs = {}
    if resx is not None and resy is not None:
        open_kwargs = {"resx": resx, "resy": resy}

    da: xr.DataArray = rioxarray.open_rasterio(  # type: ignore  # noqa: PGH003
        gti,
        chunks=(1, chunksize, chunksize),
        lock=False,
        open_kwargs=open_kwargs,
    )
    da["band"] = band_names
    if time is not None:
        da = da.expand_dims("time")
        da["time"] = [time]
    return da.to_dataset(name="variables")

However, GTI doesn't support temporal dimensions natively, so we need to:

  1. Build one GTI per dataset per time
  2. Concatenate them along the time dimension using Xarray:
def build_temporal_mosaic(gti_mosaics: list[TemporalGTIMosaic]) -> xr.Dataset:
    feature_dsets = []
    for nm, grp in groupby(
        sorted(gti_mosaics, key=lambda x: x.dataset.name),
        key=lambda x: x.dataset,
    ):
        dp = get_dataset_protocol(dataset_enum=nm)
        feature_dsets.append(
            xr.concat(
                [
                    build_gti_xarray(
                        gti=gti.gti,
                        chunksize=gti.chunksize,
                        band_names=dp.bands,
                        time=gti.time,
                    )
                    for gti in grp
                ],
                dim="time",
            )
        )
    return (
        xr.concat(feature_dsets, dim="band").transpose("time", "band", "y", "x").chunk({"time": 1})
    )

Handling Very Large Datasets with Partitioning

For extremely large datasets that might not fit into memory or cause issues with Dask's task graph, we use a partitioning approach:

@task(
    cache=True,
    cache_version=_cache_version(2),
    requests=Resources(cpu="3", mem="8Gi"),
)
def build_gti_partitions_task(
    store: str, chunk_partition_size: int, gtis: list[GTIResult]
) -> list[GTIPartition]:
    ds = xr.open_zarr(store)
    gti_partitions = []
    for nm, grp in groupby(sorted(gtis, key=lambda x: x.dataset.name), lambda x: x.dataset):
        dp = get_dataset_protocol(dataset_enum=nm)
        partition_indices = list(
            build_mosaic_chunk_partitions(
                ds=ds.chunk({"time": 1}),
                chunk_partition_size=int(chunk_partition_size),
                variable_name="variables",
                bands=dp.bands,
            )
        )
        gti_time = {gti.time: gti for gti in grp}
        for partition in partition_indices:
            t = ds.time.isel(time=slice(*partition["time"])).data[0]
            gti = gti_time[pd.Timestamp(t).to_pydatetime()]
            gti_partitions.append(GTIPartition(gti=gti, partition=partition))
    shuffle(gti_partitions)
    return gti_partitions


@task(
    cache=True,
    cache_version=_cache_version(),
    requests=Resources(cpu="3", mem="8Gi", ephemeral_storage="16Gi"),
    environment=gdal_configs.get_worker_config(8, debug=True),
)
def write_mosaic_partition_task(
    gti_partition: GTIPartition,
    store: str,
    xy_chunksize: int,
) -> bool:
    # single threaded dask scheduler but ALL_CPUS for GDAL_NUM_THREADS in environment
    with dask.config.set(scheduler="single-threaded"):
        dp = get_dataset_protocol(dataset_enum=gti_partition.gti.dataset)
        ds_time = build_gti_xarray(
            gti=gti_partition.gti.gti.download(),
            chunksize=xy_chunksize,
            band_names=dp.bands,
            time=gti_partition.gti.time,
        )
        region = {k: slice(*v) for k, v in gti_partition.partition.items()}
        # we drop time and band since the gti is now partitioned on both
        subset_slices = {k: v for k, v in region.items() if k not in ["time", "band"]}
        subset = ds_time.isel(**subset_slices)  # type: ignore  # noqa: PGH003
        subset["variables"].attrs.clear()
        subset.drop("spatial_ref").to_zarr(store, region=region)
    return True
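
The build_mosaic_chunk_partitions helper referenced above isn't reproduced here; conceptually it walks the target store's chunk grid and yields rectangular index windows, each covering a bounded number of chunks for one time step and one dataset's bands. A simplified sketch of that idea (the repository's actual implementation may differ, for example in how it groups chunks) could be:

from collections.abc import Iterator

import numpy as np
import xarray as xr

def build_mosaic_chunk_partitions(
    ds: xr.Dataset,
    chunk_partition_size: int,
    variable_name: str,
    bands: list[str],
) -> Iterator[dict[str, tuple[int, int]]]:
    """Yield rectangular index windows over the store (simplified sketch)."""
    var = ds[variable_name]
    # Chunk boundaries along y, e.g. (0, 2048, 4096, ...).
    y_edges = np.cumsum((0, *var.chunksizes["y"]))
    # Index range of this dataset's bands within the mosaic's band coordinate
    # (assumes the bands are contiguous in that coordinate).
    band_vals = list(ds.band.values)
    band_idx = sorted(band_vals.index(b) for b in bands)
    band_range = (band_idx[0], band_idx[-1] + 1)
    x_range = (0, ds.sizes["x"])

    for t in range(ds.sizes["time"]):
        # Group consecutive row-chunks into one window so each map task
        # writes several aligned chunks rather than just one.
        for i in range(0, len(y_edges) - 1, chunk_partition_size):
            j = min(i + chunk_partition_size, len(y_edges) - 1)
            yield {
                "time": (t, t + 1),
                "band": band_range,
                "y": (int(y_edges[i]), int(y_edges[j])),
                "x": x_range,
            }

Each yielded dict maps a dimension to a (start, stop) index pair, which write_mosaic_partition_task turns into slices for to_zarr(region=...).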

Then we can process these partitions in parallel using Flyte's map_task (complete code):

@workflow
def build_dataset_mosaic_workflow(
    bbox: list[float],
    times: list[datetime.datetime],
    datasets: list[DatasetEnum],
    resolution: float,
    crs: str,
    chunk_partition_size: int,
    xy_chunksize: int = 2048,
) -> str:
    ...

    gtis = map_task(
        partial(
            gdf_to_gti_task,
            crs=crs,
            resolution=resolution,
        )
    )(gdf=gdf_grouped, bounds=bounds)

    store = build_target_mosaic_task(gtis=gtis, xy_chunksize=xy_chunksize)

    gti_partitions = build_gti_partitions_task(
        store=store,
        chunk_partition_size=chunk_partition_size,
        gtis=gtis,
    )

    map_task(
        partial(
            write_mosaic_partition_task,
            store=store,
            xy_chunksize=xy_chunksize,
        ),
        concurrency=32,
    )(gti_partition=gti_partitions)

    return store

This approach has several key advantages:

  • We never materialize the entire dataset in memory, saving compute resources
  • We process chunks in a way that leverages GDAL's internal caching, which speeds up subsequent reads
  • We can parallelize across partitions using Flyte's map_task without having to deal with other parallel computing frameworks
  • The final output is a single, coherent Zarr store/Xarray dataset

Key Lessons and Caveats

Based on experience implementing these workflows, here are some important considerations:

  1. Map Task Limits: Try to keep below 5,000 tasks, as each Flyte task corresponds to a Kubernetes pod, and there are limits on the object sizes that Kubernetes' datastore (etcd) supports. For higher limits, explore the Union-specific Map over Launchplans feature.
  2. GDAL Configuration: Proper GDAL configuration makes a huge difference in performance:
def get_worker_config(memory_gb: int, debug: bool = False) -> dict[str, str]:
    """
    Apply some heuristics to determine the optimal GDAL configuration for a worker.

    Parameters
    ----------
    memory_gb : int
        The amount of memory available to the worker in gigabytes.
    debug : bool, optional
        Whether to enable debug mode, by default False.

    Returns
    -------
    dict[str, str]
        A dictionary of GDAL configuration based on the input memory.
    """
    return {
        "GDAL_HTTP_MAX_RETRY": "20",
        "GDAL_HTTP_RETRY_DELAY": "30",
        "GDAL_HTTP_MERGE_CONSECUTIVE_RANGES": "YES",
        "GDAL_HTTP_MULTIPLEX": "YES",
        "GDAL_HTTP_VERSION": "2",
        "GDAL_DISABLE_READDIR_ON_OPEN": "TRUE",
        "CPL_VSIL_CURL_CACHE_SIZE": str(1024**3 * memory_gb * 1 / 3),
        "CPL_VSIL_CURL_CHUNK_SIZE": str(1024**2 * 12),
        "VSI_CACHE": "TRUE",
        "VSI_CACHE_SIZE": str(1024**3 * memory_gb * 1 / 3),
        "GDAL_CACHEMAX": str(1024**3 * memory_gb * 1 / 2),
        "GDAL_NUM_THREADS": "ALL_CPUS",
        "CPL_DEBUG": "ON" if debug else "OFF",
        "CPL_CURL_VERBOSE": "YES" if debug else "NO",
    }
  3. Single-threaded Partition Processing: Use a single-threaded Dask scheduler for partition processing to better leverage GDAL's internal caching
  4. Chunk Size Considerations: Balance chunk sizes to avoid memory issues while minimizing the total number of tasks
  5. GTI Metadata: Provide complete metadata in your GTI files to avoid GDAL having to fetch it from each source raster

Conclusion

By combining the power of Flyte and Union.ai's orchestration capabilities with geospatial tools like GDAL and data models like Xarray, we can build truly scalable workflows for processing massive geospatial datasets. This approach allows us to:

  • Process hundreds of terabytes of geospatial data
  • Build end-to-end pipelines that are reproducible and scalable
  • Separate concerns between data ingestion, processing, and analysis
  • Enable data scientists to focus on analysis rather than infrastructure

The code examples in this post are available in the FlyteMosaic repository, where you can find complete implementations of these patterns.

A recently contributed Flyte plugin makes it easy to persist Xarray Datasets and DataArrays to Zarr between tasks.
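
With that plugin installed, passing datasets between tasks can look roughly like this sketch (the task names, store parameter, and mean-value summary are made up for illustration):

import xarray as xr
from flytekit import task, workflow

# With the Xarray/Zarr type plugin installed, xr.Dataset inputs and outputs
# are persisted to Zarr in blob storage between tasks automatically.

@task
def load_mosaic(store: str) -> xr.Dataset:
    return xr.open_zarr(store)

@task
def summarize(ds: xr.Dataset) -> float:
    return float(ds["variables"].mean().compute())

@workflow
def mosaic_summary_wf(store: str) -> float:
    return summarize(ds=load_mosaic(store=store))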

Whether you're building canopy height models, soil carbon predictions, or any other large-scale geospatial application, these patterns can help you move from notebook experiments to production-ready data pipelines.

Next Steps

Start building your own geospatial workflows with Flyte and Union.ai today.
