As geospatial machine learning engineers, we often face a common challenge: how do we effectively build, manage, and process large-scale geospatial datasets that can reach terabytes or even petabytes in size? Traditional approaches using notebooks or simple scripts quickly break down at scale, and manually babysitting these processes becomes unsustainable.
This post explores how to build scalable mosaic workflows using Union.ai and the Flyte orchestration platform combined with GDAL, Xarray, and Dask. This approach has been tested at multiple organizations, enabling us to build scalable, end-to-end remote sensing pipelines. Teams across data-intensive industries like geospatial, autonomous vehicles, and biotech rely on pipelines like these for mission-critical projects. Let’s learn how.
The Foundation: Standards for Remote Sensing and Machine Learning
Before diving into the technical implementation, let's establish some standards that make large-scale geospatial computing more manageable:
1. Data Models with Xarray
Xarray provides an excellent foundation for modeling mosaic datasets. Think of it as representing one large image with multiple dimensions:
Dimensions: time, bands, y, x
Coordinates: timestamps, band names, latitude, longitude
Data: The actual pixel values
# Example of an Xarray dataset structure
# dimensions: (time, band, y, x)
# coordinates:
# * time: array of datetime64 values
# * band: array of strings like 'red', 'green', 'blue', 'nir'
# * y: array of y coordinates (can be latitude)
# * x: array of x coordinates (can be longitude)
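As a toy illustration (synthetic data and made-up coordinates, not taken from the repo), such a dataset can be constructed directly:

import numpy as np
import pandas as pd
import xarray as xr

# three monthly timestamps, four spectral bands, and a small 256 x 256 pixel grid
times = pd.date_range("2024-01-01", periods=3, freq="MS")
bands = ["red", "green", "blue", "nir"]
y = np.linspace(50.0, 49.0, 256)  # e.g. latitude or projected northings
x = np.linspace(10.0, 11.0, 256)  # e.g. longitude or projected eastings

ds = xr.DataArray(
    np.zeros((len(times), len(bands), len(y), len(x)), dtype="uint16"),
    dims=("time", "band", "y", "x"),
    coords={"time": times, "band": bands, "y": y, "x": x},
    name="variables",  # same variable name the later snippets use
).to_dataset()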
This data model allows us to:
Represent arbitrarily large arrays
Chunk data in optimal ways for processing
Lazy-load specific regions of interest
Apply operations to the entire dataset
2. Compute with Dask
Dask works exceptionally well with Xarray for distributed computing, particularly for:
Operations on reasonably sized and properly chunked datasets
Short-lived tasks like interpolation
Distributing work across clusters
However, Dask has limitations:
Very small chunks can create too many tasks
It's not a workflow orchestration tool
You still need something to manage the overall process
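Even so, the Xarray + Dask combination is straightforward to use when chunks are sensible. Continuing the toy dataset from the sketch above, chunking turns the data into a lazy Dask graph that only executes when a result is actually requested:

# hypothetical sketch: nothing is read or computed until .compute() (or a write) is called
chunked = ds.chunk({"time": 1, "band": -1, "y": 128, "x": 128})

# a short-lived, well-partitioned operation, e.g. a per-scene spatial mean
scene_means = chunked["variables"].mean(dim=["y", "x"])

print(scene_means.data)       # still lazy: a Dask array backed by a task graph
print(scene_means.compute())  # executes the graph on the local or distributed scheduler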
3. Mosaics with GDAL
For building mosaics (stitching together multiple scenes into one large image), GDAL offers powerful tools:
VRT (GDAL Virtual Raster): Maps and combines multiple files that share the same coordinate reference system
Warped VRT: Handles reprojection between different coordinate reference systems
GTI (GDAL Raster Tile Index): A newer approach (GDAL 3.9+) that indexes rasters together with their metadata for faster access
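As a rough illustration (placeholder file names, not from the repo), these building blocks look like this through GDAL's Python bindings:

from osgeo import gdal

gdal.UseExceptions()

scenes = ["scene_a.tif", "scene_b.tif"]  # placeholder inputs that share a CRS

# plain VRT: stitch same-CRS scenes into one virtual mosaic without copying pixels
gdal.BuildVRT("mosaic.vrt", scenes)

# warped VRT: reproject on the fly when sources and target use different CRSs
gdal.Warp("mosaic_3857.vrt", "mosaic.vrt", format="VRT", dstSRS="EPSG:3857")

# GTI (GDAL >= 3.9): a tile index built with gdaltindex (e.g. mosaic.gti.fgb) stores
# per-tile extents and metadata in a vector layer, so opening stays fast even when a
# mosaic spans tens of thousands of scenes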
Implementing Scalable Workflows with Flyte and Union.ai
Now let's see how to combine these standards with Flyte orchestration to build truly scalable workflows. Flyte is the leading OSS orchestration tool from Union.ai.
Building a Complete Mosaic Workflow
Let's look at a more complete example from the FlyteMosaic repo. For extremely large datasets that might not fit into memory or would overwhelm Dask's task graph, we use a partitioning approach:
@task(
    cache=True,
    cache_version=_cache_version(2),
    requests=Resources(cpu="3", mem="8Gi"),
)
def build_gti_partitions_task(
    store: str, chunk_partition_size: int, gtis: list[GTIResult]
) -> list[GTIPartition]:
    ds = xr.open_zarr(store)
    gti_partitions = []
    for nm, grp in groupby(sorted(gtis, key=lambda x: x.dataset.name), lambda x: x.dataset):
        dp = get_dataset_protocol(dataset_enum=nm)
        partition_indices = list(
            build_mosaic_chunk_partitions(
                ds=ds.chunk({"time": 1}),
                chunk_partition_size=int(chunk_partition_size),
                variable_name="variables",
                bands=dp.bands,
            )
        )
        gti_time = {gti.time: gti for gti in grp}
        for partition in partition_indices:
            t = ds.time.isel(time=slice(*partition["time"])).data[0]
            gti = gti_time[pd.Timestamp(t).to_pydatetime()]
            gti_partitions.append(GTIPartition(gti=gti, partition=partition))
    shuffle(gti_partitions)
    return gti_partitions
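Before looking at the task that writes each partition, it helps to have a rough picture of what these objects carry. The sketch below is inferred from how they are used in the snippets; the real definitions live in the FlyteMosaic repo and may differ:

from dataclasses import dataclass
from datetime import datetime
from enum import Enum

from flytekit.types.file import FlyteFile


class Dataset(str, Enum):
    EXAMPLE = "example"  # placeholder for the repo's dataset enum


@dataclass
class GTIResult:
    dataset: Dataset  # which source dataset the tile index belongs to
    time: datetime    # acquisition timestamp the GTI covers
    gti: FlyteFile    # the tile-index file itself, downloaded lazily on the worker


@dataclass
class GTIPartition:
    gti: GTIResult
    partition: dict[str, tuple[int, int]]  # dim name -> (start, stop) index bounds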
@task(
    cache=True,
    cache_version=_cache_version(),
    requests=Resources(cpu="3", mem="8Gi", ephemeral_storage="16Gi"),
    environment=gdal_configs.get_worker_config(8, debug=True),
)
def write_mosaic_partition_task(
    gti_partition: GTIPartition,
    store: str,
    xy_chunksize: int,
) -> bool:
    # single-threaded dask scheduler but ALL_CPUS for GDAL_NUM_THREADS in environment
    with dask.config.set(scheduler="single-threaded"):
        dp = get_dataset_protocol(dataset_enum=gti_partition.gti.dataset)
        ds_time = build_gti_xarray(
            gti=gti_partition.gti.gti.download(),
            chunksize=xy_chunksize,
            band_names=dp.bands,
            time=gti_partition.gti.time,
        )
        region = {k: slice(*v) for k, v in gti_partition.partition.items()}
        # we drop time and band since the gti is now partitioned on both
        subset_slices = {k: v for k, v in region.items() if k not in ["time", "band"]}
        subset = ds_time.isel(**subset_slices)  # type: ignore # noqa: PGH003
        subset["variables"].attrs.clear()
        subset.drop("spatial_ref").to_zarr(store, region=region)
    return True
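One detail these snippets rely on: the target Zarr store must already exist with the full mosaic shape so that each task can fill its own region independently. A generic way to set that up with Xarray (a sketch with placeholder sizes, not necessarily how the repo does it) is a metadata-only write:

import dask.array as da
import pandas as pd
import xarray as xr

times = pd.date_range("2024-01-01", periods=12, freq="MS")
bands = ["red", "green", "blue", "nir"]
ny = nx = 100_000  # placeholder mosaic dimensions

template = xr.Dataset(
    {
        "variables": (
            ("time", "band", "y", "x"),
            da.zeros(
                (len(times), len(bands), ny, nx),
                dtype="float32",
                chunks=(1, len(bands), 2048, 2048),
            ),
        )
    },
    coords={"time": times, "band": bands},
)

# compute=False writes dimensions, coordinates, and chunk layout, but no pixel data;
# worker tasks then fill disjoint regions with to_zarr(store, region=...)
template.to_zarr("mosaic.zarr", compute=False)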
Then we can process these partitions in parallel using Flyte's map_task (the complete code is in the FlyteMosaic repo).
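A minimal sketch of how that fan-out might be wired (the task names mirror the snippets above; the upstream task producing the GTIs is assumed, and the real workflow does more):

import functools

from flytekit import map_task, workflow


@workflow
def mosaic_workflow(store: str, chunk_partition_size: int, xy_chunksize: int) -> list[bool]:
    gtis = build_gtis_task(store=store)  # assumed upstream task returning list[GTIResult]
    partitions = build_gti_partitions_task(
        store=store, chunk_partition_size=chunk_partition_size, gtis=gtis
    )
    # fix store and xy_chunksize with functools.partial and map only over the partitions;
    # each list element becomes its own Flyte task (and Kubernetes pod)
    return map_task(
        functools.partial(write_mosaic_partition_task, store=store, xy_chunksize=xy_chunksize)
    )(gti_partition=partitions)

This approach has several advantages: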
We never materialize the entire dataset in memory, saving compute resources
We process chunks in a way that leverages GDAL's internal caching, which speeds up subsequent iterations
We can parallelize across partitions using Flyte's map_task without having to deal with other parallel computing frameworks
The final output is a single, coherent Zarr store/Xarray dataset
Key Lessons and Caveats
Based on experience implementing these workflows, here are some important considerations:
Map Task Limits: Try to stay below 5,000 tasks, as each Flyte task corresponds to a Kubernetes pod, and there are limits on the object sizes that the Kubernetes database (etcd) supports. For higher limits, explore Union-specific Map over Launchplans
GDAL Configuration: Proper GDAL configuration makes a huge difference in performance:
def get_worker_config(memory_gb: int, debug: bool = False) -> dict[str, str]:
    """
    Apply some heuristics to determine the optimal GDAL configuration for a worker.

    Parameters
    ----------
    memory_gb : int
        The amount of memory available to the worker in gigabytes.
    debug : bool, optional
        Whether to enable debug mode, by default False.

    Returns
    -------
    dict[str, str]
        A dictionary of GDAL configuration based on the input memory.
    """
    return {
        "GDAL_HTTP_MAX_RETRY": "20",
        "GDAL_HTTP_RETRY_DELAY": "30",
        "GDAL_HTTP_MERGE_CONSECUTIVE_RANGES": "YES",
        "GDAL_HTTP_MULTIPLEX": "YES",
        "GDAL_HTTP_VERSION": "2",
        "GDAL_DISABLE_READDIR_ON_OPEN": "TRUE",
        "CPL_VSIL_CURL_CACHE_SIZE": str(1024**3 * memory_gb * 1 / 3),
        "CPL_VSIL_CURL_CHUNK_SIZE": str(1024**2 * 12),
        "VSI_CACHE": "TRUE",
        "VSI_CACHE_SIZE": str(1024**3 * memory_gb * 1 / 3),
        "GDAL_CACHEMAX": str(1024**3 * memory_gb * 1 / 2),
        "GDAL_NUM_THREADS": "ALL_CPUS",
        "CPL_DEBUG": "ON" if debug else "OFF",
        "CPL_CURL_VERBOSE": "YES" if debug else "NO",
    }
Single-threaded Partition Processing: Use a single-threaded Dask scheduler for partition processing to better leverage GDAL's internal caching
Chunk Size Considerations: Balance chunk sizes to avoid memory issues while keeping the total number of tasks manageable (see the back-of-the-envelope sketch after this list)
GTI Metadata: Provide complete metadata in your GTI files to avoid GDAL needing to fetch it
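For example, here is a back-of-the-envelope calculation (with made-up numbers) showing how chunk size drives the map-task count:

import math

# assumed mosaic: 200,000 x 200,000 pixels, 12 time steps,
# partitioned into 8,192 x 8,192 spatial blocks per time step
width = height = 200_000
time_steps = 12
block = 8_192

partitions = time_steps * math.ceil(width / block) * math.ceil(height / block)
print(partitions)  # 12 * 25 * 25 = 7,500 -> over the ~5,000 guideline, so either
# enlarge the blocks or map over launch plans instead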
Conclusion
By combining the power of Flyte and Union.ai's orchestration capabilities with geospatial tools like GDAL and data models like Xarray, we can build truly scalable workflows for processing massive geospatial datasets. This approach allows us to:
Process hundreds of terabytes of geospatial data
Build end-to-end pipelines that are reproducible and scalable
Separate concerns between data ingestion, processing, and analysis
Enable data scientists to focus on analysis rather than infrastructure
The code examples in this post are available in the FlyteMosaic repository, where you can find complete implementations of these patterns.
Whether you're building canopy height models, soil carbon predictions, or any other large-scale geospatial application, these patterns can help you move from notebook experiments to production-ready data pipelines.