Samhita Alla

Parallel Audio Transcription: Using Whisper, JAX and Flyte Map Tasks for Streamlined Batch Inference

Imagine you have a dataset of thousands of audio files that need to be transcribed. Running the transcription process sequentially for each file can be time-consuming and resource-intensive. By leveraging parallel batch inference, you can significantly reduce the overall processing time by transcribing multiple files simultaneously, and you can make more efficient use of modern hardware such as GPUs. However, setting up and scaling the infrastructure for parallel batch inference can be a complex challenge.

In this blog post, we’ll demonstrate how you can use map tasks to perform parallel batch inference efficiently within Union Cloud, the managed offering of Flyte. We’ll also present an experiment we conducted to showcase this approach, using a JAX Whisper model for audio transcription. The end result will be a production-level batch-inference pipeline that can process large amounts of audio data with ease.

We opted for JAX because:

  • JAX offers seamless scalability to run on one or multiple GPUs.
  • JAX provides robust support for large-scale data parallelism through its `pmap` function.

JAX encompasses numerous other advantages, including automatic differentiation of arbitrary functions, JIT compilation for performance optimization and automatic vectorization for efficient computation. 
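To make the `pmap` point concrete, here's a minimal, self-contained sketch (a toy function, not Whisper) that JIT-compiles a computation and then shards a batch across whatever accelerators JAX can see:

import jax
import jax.numpy as jnp


# JIT compilation: XLA fuses these operations into a single optimized kernel.
@jax.jit
def normalize(x):
    return (x - x.mean()) / (x.std() + 1e-6)


# Data parallelism: pmap replicates the function across all local devices and
# runs each slice of the leading axis on a different device.
@jax.pmap
def normalize_sharded(x):
    return (x - x.mean()) / (x.std() + 1e-6)


n_devices = jax.local_device_count()
# Dummy "audio feature" batch with one slice per device.
batch = jnp.ones((n_devices, 8, 3000))
print(normalize_sharded(batch).shape)  # (n_devices, 8, 3000)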

Note: The complete code and detailed instructions for running the pipeline can be found on our GitHub repository.

A closer look at batch inference

Batch inference is a technique for processing large amounts of data in parallel: predictions are generated for a batch of inputs rather than one input at a time. The model is fed a group of inputs all at once, which allows the system to optimize data processing, reduce latency and improve the efficiency of the prediction process.
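As a toy illustration (a matrix multiply standing in for a model's forward pass), the difference looks like this:

import numpy as np

# Stand-in "model": a single weight matrix applied to 3,000-dimensional features.
weights = np.random.rand(3000, 512)


def predict(features: np.ndarray) -> np.ndarray:
    return features @ weights


samples = [np.random.rand(3000) for _ in range(8)]

# One at a time: eight separate forward passes.
sequential = [predict(s) for s in samples]

# Batched: the same eight samples as a single (8, 3000) array, so the overhead
# of each call is amortized across the whole batch.
batched = predict(np.stack(samples))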

Implementing batch processing for inference can be a challenging task in production; it requires careful consideration of how to balance compute resources to avoid out-of-memory errors and optimize the use of specialized hardware like GPUs. Other factors to consider include caching repetitive operations, handling partial execution failures without restarting the whole run and monitoring resource utilization. Building all of these features yourself can be a time-consuming and tedious process.
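Flyte covers several of these concerns declaratively. As a minimal sketch (the task itself is illustrative, not part of the pipeline below), caching and retries are just arguments to the task decorator:

from flytekit import Resources, task


# cache=True skips recomputation when the inputs haven't changed, and retries
# recover from transient failures without rerunning the whole workflow.
@task(
    cache=True,
    cache_version="1.0",
    retries=3,
    requests=Resources(gpu="1", mem="15Gi", cpu="2"),
)
def transcribe_one(audio_url: str) -> str:
    ...  # model inference for a single file would go here
    return ""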

Google Colab and Kaggle Notebooks are excellent environments for prototyping machine learning models, but running inference pipelines at scale requires reliable infrastructure. That's where Union Cloud comes in: It provides a robust platform for running batch inference pipelines at scale. In this post, we'll show you how to create Flyte workflows that use map tasks to run a Whisper JAX model on a V100 GPU. So let's get started!

But first, Whisper PyTorch in a single container

JAX and PyTorch are both widely used deep-learning frameworks, but JAX can often provide better performance than PyTorch. Before we move on to running Whisper JAX on Flyte, let's first run the Whisper PyTorch pipeline on Flyte.

import json
import os

import numpy as np
import requests
import torch
from flytekit import Resources, task, workflow
from flytekit.types.file import FlyteFile
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"


@task(requests=Resources(gpu="1", mem="15Gi", cpu="2"))
def torch_transcribe(
    checkpoint: str,
    audio: FlyteFile,
    chunk_length: float,
    batch_size: int,
    return_timestamps: bool,
) -> str:
    pipe = pipeline(
        "automatic-speech-recognition",
        model=checkpoint,
        chunk_length_s=chunk_length,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )

    local_audio_path = audio.download()
    if local_audio_path.startswith("http://") or local_audio_path.startswith(
        "https://"
    ):
        inputs = requests.get(local_audio_path).content
    else:
        with open(local_audio_path, "rb") as f:
            inputs = f.read()

    if isinstance(inputs, bytes):
        inputs = ffmpeg_read(inputs, 16000)

    if not isinstance(inputs, np.ndarray):
        raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")

    if len(inputs.shape) != 1:
        raise ValueError(
            "We expect a single channel audio input for AutomaticSpeechRecognitionPipeline"
        )

    prediction = pipe(
        inputs, batch_size=batch_size, return_timestamps=return_timestamps
    )
    return json.dumps(prediction)


@workflow
def torch_wf(
    checkpoint: str = "openai/whisper-large-v2",
    audio: FlyteFile = "https://huggingface.co/datasets/Samhita/whisper-jax-examples/resolve/main/khloe_kardashian_podcast.mp3",
    chunk_length: float = 30.0,
    batch_size: int = 8,
    return_timestamps: bool = False,
) -> str:
    return torch_transcribe(
        checkpoint=checkpoint,
        audio=audio,
        chunk_length=chunk_length,
        batch_size=batch_size,
        return_timestamps=return_timestamps,
    )

The pipeline in this example implements a chunking algorithm that allows for transcription of audio samples with arbitrary lengths. When triggered on the Flyte backend, the workflow is allocated 1 GPU, 2 CPUs, and 15 GB of memory for the transcription task. A batch size of 16 resulted in a CUDA out-of-memory error, but reducing the batch size to 8 allowed the workflow to complete successfully. The inference time for a 59-minute audio was approximately 5 minutes with a batch size of 8. It's worth noting that this pipeline runs entirely within a single container.
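As a side note, a Flyte workflow is also a plain Python callable, so the same code can be exercised locally before registering it on the backend. A minimal sketch, appended to the same file (resource requests are ignored locally, so expect this to be slow without a GPU):

if __name__ == "__main__":
    # Local execution for a quick sanity check; print the start of the transcript.
    print(torch_wf(batch_size=8)[:500])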

Whisper JAX in a single container

Now, let's proceed to running Whisper JAX on Flyte, in a single container.

import json

import jax.numpy as jnp
from flytekit import Resources, task, workflow
from flytekit.types.file import FlyteFile
from whisper_jax import FlaxWhisperPipline


@task(
    requests=Resources(gpu="1", mem="15Gi", cpu="2"),
)
def jax_transcribe(
    audio: FlyteFile,
    chunk_length_s: float,
    stride_length_s: float,
    batch_size: int,
    language: str,
    task: str,
    return_timestamps: bool,
    checkpoint: str,
) -> str:
    pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.float16, batch_size=batch_size)
    return json.dumps(
        pipeline(
            audio.download(),
            chunk_length_s,
            stride_length_s,
            batch_size,
            language,
            task,
            return_timestamps,
        )
    )


@workflow
def jax_simple_wf(
    audio: FlyteFile = "https://huggingface.co/datasets/Samhita/whisper-jax-examples/resolve/main/khloe_kardashian_podcast.mp3",
    checkpoint: str = "openai/whisper-large-v2",
    chunk_length_s: float = 30.0,
    stride_length_s: float = 5.0,
    batch_size: int = 8,
    language: str = "en",
    task: str = "transcribe",
    return_timestamps: bool = False,
) -> str:
    return jax_transcribe(
        audio=audio,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_length_s,
        batch_size=batch_size,
        language=language,
        task=task,
        return_timestamps=return_timestamps,
        checkpoint=checkpoint,
    )

The Whisper JAX pipeline also utilizes a chunking algorithm and leverages JAX's `pmap` function for data parallelism across GPU/TPU devices. The JAX transcription task is assigned 1 GPU, 2 CPUs, and 15 GB of memory. With a batch size of 8, the pipeline was able to transcribe the same 59-minute audio file in about 9 minutes on a V100 GPU.

When using a batch size of 16, the inference time averaged approximately 8 minutes.

Whoa, turns out JAX is slower than PyTorch because we can't tap into those compilation cache speed boosts. Understanding the factors that trigger JAX cache hits and misses is currently challenging. However, if the caching process becomes more configurable, we could potentially leverage it to achieve even better performance.

Parallel task fusion with map tasks

In production scenarios, it is common to generate transcriptions for a batch of audio inputs. To accomplish this efficiently, you can leverage map tasks. Flyte's map tasks allow for parallel transcription, leading to significantly faster processing times compared to running transcriptions sequentially. We've made only minor adaptations to the original Whisper JAX code; before diving into the Whisper-specific tasks, here's the basic map-task pattern in isolation.
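The task and inputs in this sketch are made up purely to show the fan-out; the real pipeline maps over audio chunks instead:

from typing import List

from flytekit import map_task, task, workflow


# Hypothetical stand-in: pretend each call transcribes one audio chunk.
@task
def transcribe_chunk(chunk_id: int) -> str:
    return f"transcript for chunk {chunk_id}"


@workflow
def fan_out_wf(chunk_ids: List[int] = [0, 1, 2, 3]) -> List[str]:
    # map_task fans `transcribe_chunk` out over the list; the copies run in parallel.
    return map_task(transcribe_chunk)(chunk_id=chunk_ids)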

`forward` task

For the most part, the `forward` task remains true to the original code. To achieve parallel execution, we'll wrap it in a map task and allocate the necessary resources when the mapped task is invoked (via `with_overrides`, as shown later).

@task
def forward(
    model_inputs: List[np.ndarray],
    batch_size: Optional[int],
    language: Optional[str],
    task: Optional[str],
    return_timestamps: Optional[bool],
    max_length: Optional[int],
    checkpoint: str,
) -> List[np.ndarray]:
    model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
        checkpoint, _do_init=False, dtype=jnp.float16, cache_dir="whisper-models"
    )

    max_length = model.generation_config.max_length if max_length == 0 else max_length
    params = jax_utils.replicate(params)

    model_inputs = {
        "input_features": model_inputs[0],
        "stride": model_inputs[1].tolist(),
    }

    # We need to keep track of some additional input arguments for post-processing, so we forward these on after running generation
    input_features = model_inputs.pop("input_features")
    input_batch_size = input_features.shape[0]

    if input_batch_size != batch_size:
        padding = np.zeros(
            [batch_size - input_batch_size, *input_features.shape[1:]],
            input_features.dtype,
        )
        input_features = np.concatenate([input_features, padding])

    pred_ids = forward_generate(
        input_features=input_features,
        model=model,
        max_length=max_length,
        params=params,
        language=language,
        task=task,
        return_timestamps=return_timestamps,
    )[:input_batch_size]

    # tokenizer's decode method expects an extra dim - we insert it here for convenience
    out = {"tokens": pred_ids[:, None, :]}

    stride = model_inputs.pop("stride", None)
    if stride is not None:
        out["stride"] = stride

    return [out["tokens"], np.array(out["stride"])]

`postprocess` task

Post-processing consolidates the transcriptions of the audio chunks. The code remains mostly unchanged from the original, except for the task decorator and resource allocation.

@task(requests=Resources(mem="5Gi", cpu="2", gpu="1"))
def postprocess(
    model_outputs: List[List[np.ndarray]],
    chunk_length: int,
    sampling_rate: int,
    max_source_positions: int,
    tokenizer: WhisperTokenizer,
    return_timestamps: bool,
) -> str:
    unpacked_model_outputs = []
    for output in model_outputs:
        model_output = {"tokens": output[0], "stride": output[1].tolist()}
        for t in zip(*model_output.values()):
            unpacked_model_outputs.append(dict(zip(model_output, t)))

    time_precision = chunk_length / max_source_positions
    # Send the chunking back to seconds, it's easier to handle in whisper
    for output in unpacked_model_outputs:
        if "stride" in output:
            chunk_len, stride_left, stride_right = output["stride"]
            # Go back in seconds
            chunk_len /= sampling_rate
            stride_left /= sampling_rate
            stride_right /= sampling_rate
            output["stride"] = chunk_len, stride_left, stride_right

    text, optional = tokenizer._decode_asr(
        unpacked_model_outputs,
        return_timestamps=return_timestamps,
        return_language=None,
        time_precision=time_precision,
    )
    return json.dumps({"text": text, **optional})

Bringing it all together

The last step entails invoking the `forward` and `postprocess` tasks on the audio chunks to run the transcription.

@dynamic(requests=Resources(mem="10Gi", cpu="4", gpu="1"))
def jax_batch_inference(
    audios: List[FlyteFile],
    checkpoint: str,
    max_length: int,
    chunk_length_s: float,
    stride_length_s: float,
    batch_size: int,
    language: str,
    task: str,
    return_timestamps: bool,
) -> List[str]:
    processor = WhisperProcessor.from_pretrained(checkpoint)
    feature_extractor = processor.feature_extractor
    tokenizer = processor.tokenizer

    transcripts = []
    for audio in audios:
        dataloader = preprocess_batch(
            inputs=audio.download(),
            feature_extractor=feature_extractor,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            batch_size=batch_size,
        )

        dataloader_to_list = list(
            map(
                lambda batch: [batch["input_features"], np.array(batch["stride"])],
                dataloader,
            )
        )

        # fan the `forward` task out over our chunked audio samples
        map_task_partial = functools.partial(
            forward,
            batch_size=batch_size,
            language=language,
            task=task,
            return_timestamps=return_timestamps,
            max_length=max_length,
            checkpoint=checkpoint,
        )
        model_outputs = map_task(map_task_partial)(
            model_inputs=dataloader_to_list
        ).with_overrides(requests=Resources(mem="20Gi", cpu="2", gpu="1"))

        transcripts.append(
            postprocess(
                model_outputs=model_outputs,
                chunk_length=feature_extractor.chunk_length,
                sampling_rate=feature_extractor.sampling_rate,
                # model.config.max_source_positions
                max_source_positions=1500,
                tokenizer=tokenizer,
                return_timestamps=return_timestamps,
            )
        )
    return transcripts


@workflow
def jax_batch_inference_wf(
    checkpoint: str = "openai/whisper-large-v2",
    max_length: int = 0,
    chunk_length_s: float = 30.0,
    stride_length_s: float = 5.0,
    batch_size: int = 16,
    language: str = "en",
    task: str = "transcribe",
    return_timestamps: bool = False,
    audios: List[FlyteFile] = [
        "https://datasets-server.huggingface.co/assets/librispeech_asr/--/all/train.clean.100/1/audio/audio.mp3",
        "https://huggingface.co/datasets/Samhita/SadTalkerData/resolve/main/Audio%20-%20Oprah%20Winfrey.mp3",
        "https://datasets-server.huggingface.co/assets/sanchit-gandhi/whisper-jax-test-files/--/sanchit-gandhi--whisper-jax-test-files/train/0/audio/audio.mp3",
        "https://datasets-server.huggingface.co/assets/sanchit-gandhi/whisper-jax-test-files/--/sanchit-gandhi--whisper-jax-test-files/train/1/audio/audio.mp3",
        "https://huggingface.co/datasets/Samhita/whisper-jax-examples/resolve/main/khloe_kardashian_podcast.mp3",
    ],
) -> List[str]:
    return jax_batch_inference(
        checkpoint=checkpoint,
        max_length=max_length,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_length_s,
        batch_size=batch_size,
        language=language,
        task=task,
        return_timestamps=return_timestamps,
        audios=audios,
    )

In addition to parallelizing the audio transcriptions, we're also running the audio chunks concurrently. This boosts speed, and it’s ideal for handling large-scale batch inference in a production environment.

The exact total runtime of the end-to-end workflow is hard to pin down since the map tasks run in parallel. However, the 59-minute audio file from earlier took 6 to 7 minutes to complete within the batch (with the other files running concurrently). For comparison, transcribing that same audio on its own took about 5 minutes.

Support for multiple GPUs works seamlessly out of the box. You can specify the desired number of GPUs in the resource request, such as `gpu="2"`, and Flyte and JAX will handle the rest, as shown in the snippet below.
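For instance, here is the mapped `forward` invocation from the dynamic workflow above with the override bumped from one GPU to two (a sketch of the change only, not a benchmarked configuration):

        # Same mapped forward pass as before, now requesting two GPUs per task;
        # jax.pmap replicates the parameters and shards the batch across both devices.
        model_outputs = map_task(map_task_partial)(
            model_inputs=dataloader_to_list
        ).with_overrides(requests=Resources(mem="20Gi", cpu="2", gpu="2"))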

Thanks to the magic of map tasks, we unleashed the power of parallel processing. Batch inference is a whole different ball game compared to single model prediction, particularly concerning latency and efficient resource utilization. Union Cloud lets you process large amounts of data, run batch inference and train complex models, including large language models, seamlessly, all while enabling load balancing, resource allocation, scalability and optimal performance.

It's time to rev up those GPUs and make them go brrrr!
