
Enable Voxtral Realtime on XNNPACK (CPU) #17431

Merged
mergennachin merged 3 commits into main from voxtral_realtime on Feb 13, 2026

Conversation

@mergennachin (Contributor) commented Feb 12, 2026

Adds Mistral's Voxtral-Mini-4B-Realtime-2602 (~4B parameter streaming
speech-to-text model) to ExecuTorch with XNNPACK backend support.

Phase 1: Self-contained eager model (model.py) with direct Mistral
checkpoint loading, multi-method export (audio_encoder, text_decoder,
token_embedding) to a single .pte, and TorchAO quantization (4-bit blockwise weights with 8-bit dynamic activations for linear layers, and 8-bit per-channel weights for embeddings).
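
For illustration, here is a minimal sketch of what this multi-method export plus TorchAO quantization flow could look like. The `VoxtralRealtime` class, its `from_checkpoint` loader, and the example input shapes are hypothetical placeholders, not the PR's actual `export_voxtral_rt.py` code:

```python
# Hedged sketch only; model class, loader, and input shapes are placeholders.
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_

model = VoxtralRealtime.from_checkpoint("consolidated.safetensors")  # hypothetical loader

# 8-bit dynamic activations + 4-bit blockwise weights on the decoder's linear layers.
quantize_(model.text_decoder, Int8DynamicActivationInt4WeightConfig(group_size=32))

# Export each method separately, then bundle all three into a single .pte.
mel = torch.randn(1, 128, 3000)          # placeholder mel-spectrogram input
embeds = torch.randn(1, 1, 3072)         # placeholder fused embedding input
tokens = torch.zeros(1, 1, dtype=torch.long)
programs = {
    "audio_encoder": torch.export.export(model.audio_encoder, (mel,)),
    "text_decoder": torch.export.export(model.text_decoder, (embeds,)),
    "token_embedding": torch.export.export(model.token_embedding, (tokens,)),
}
lowered = to_edge_transform_and_lower(programs, partitioner=[XnnpackPartitioner()])
with open("voxtral_realtime.pte", "wb") as f:
    f.write(lowered.to_executorch().buffer)
```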

Phase 2: C++ runner for offline transcription. Loads preprocessor.pte
for mel spectrogram computation, runs audio encoding, then autoregressive
decoding with element-wise audio+text embedding fusion.
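
For illustration, the same offline flow sketched in Python against ExecuTorch's Python runtime (the C++ runner is the actual implementation; the shapes, special-token IDs, the decoder's positional argument, and the stop condition below are simplifying assumptions):

```python
# Hedged Python sketch of the runner's offline flow; not the PR's C++ code.
import torch
from executorch.runtime import Runtime

rt = Runtime.get()
preprocess = rt.load_program("preprocessor.pte").load_method("forward")
program = rt.load_program("voxtral_realtime.pte")
audio_encoder = program.load_method("audio_encoder")
token_embedding = program.load_method("token_embedding")
text_decoder = program.load_method("text_decoder")

waveform = torch.zeros(1, 480000)            # placeholder: 30 s of 16 kHz audio
bos_id, eos_id, max_positions = 1, 2, 2048   # placeholder token IDs / limit

mel = preprocess.execute([waveform])[0]            # 1. mel spectrogram
audio_embeds = audio_encoder.execute([mel])[0]     # 2. audio embeddings (1, T_audio, D)

# 3. Autoregressive decoding with element-wise audio+text embedding fusion.
token, out_tokens = bos_id, []
for pos in range(max_positions):
    tok_embed = token_embedding.execute([torch.tensor([[token]], dtype=torch.long)])[0]
    fused = tok_embed
    if pos < audio_embeds.shape[1]:                # fuse while audio frames remain
        fused = fused + audio_embeds[:, pos : pos + 1, :]
    logits = text_decoder.execute([fused, torch.tensor([pos], dtype=torch.long)])[0]
    token = int(logits[:, -1].argmax(-1))
    if token == eos_id:
        break
    out_tokens.append(token)
```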

Phase 3: Streaming support (follow-up PR: #17440)

Phase 4: Enable on CUDA and Metal (follow-up PR)

Example output (8da4w quantized, 30s LibriSpeech audio):

$ cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \
    --model_path voxtral_realtime.pte \
    --tokenizer_path tekken.json \
    --preprocessor_path preprocessor.pte \
    --audio_path output.wav

Mr. Quilter is the apostle of the middle classes, and we are glad to
welcome his gospel. Nor is Mr. Quilter's manner less interesting than
his matter. He tells us that at this festive season of the year, with
Christmas and roast beef looming before us, similes drawn from eating
and its results occur most readily to the mind. He has grave doubts
whether Sir Frederick Layton's work is really Greek after all, and...

Generated 392 tokens in 44s (~8.8 tok/s) on an M1 MacBook.

Test Plan: https://github.com/pytorch/executorch/actions/runs/21986674876/job/63522558208?pr=17431

Copilot AI review requested due to automatic review settings February 12, 2026 22:14

pytorch-bot bot commented Feb 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17431

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Pending

As of commit c8784d8 with merge base f08db65:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Feb 12, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

mergennachin changed the title from "Enable Voxtral Realtime on XNNPACK" to "Enable Voxtral Realtime on XNNPACK (CPU)" on Feb 12, 2026
Copilot AI left a comment

Pull request overview

This PR adds Mistral's Voxtral-Mini-4B-Realtime-2602 streaming speech-to-text model to ExecuTorch with XNNPACK backend support. The implementation is self-contained with direct checkpoint loading (no HuggingFace dependency) and includes three phases: eager model implementation with multi-method export, C++ runner for offline transcription, and hooks for future streaming support.

Changes:

  • Introduces a shared quantization module (extension/llm/export/quantize.py) for TorchAO source-transform quantization, supporting 4w/8w/8da4w/8da8w for linear layers and 4w/8w for embeddings (a minimal sketch follows this list)
  • Implements Voxtral Realtime model with custom audio encoder, text decoder, and element-wise audio+text embedding fusion
  • Adds C++ runner with mel spectrogram preprocessing, audio encoding, and autoregressive decoding
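
A minimal sketch of how such a shared helper might map these scheme names onto TorchAO configs; the function name and signature are illustrative, and the real `extension/llm/export/quantize.py` may differ:

```python
# Hedged sketch; not the actual API of extension/llm/export/quantize.py.
import torch
from torchao.quantization import (
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)

def quantize_linear_layers(model: torch.nn.Module, scheme: str, group_size: int = 32):
    configs = {
        "4w": Int4WeightOnlyConfig(group_size=group_size),
        "8w": Int8WeightOnlyConfig(),
        "8da4w": Int8DynamicActivationInt4WeightConfig(group_size=group_size),
        "8da8w": Int8DynamicActivationInt8WeightConfig(),
    }
    # Source transform: quantize_ rewrites the module's linear layers in place.
    quantize_(model, configs[scheme])
    return model
```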

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| extension/llm/export/quantize.py | New shared TorchAO quantization module for LLM export with support for various weight-only and dynamic activation quantization schemes |
| extension/llm/export/BUCK | Added quantize.py to build configuration |
| examples/models/voxtral_realtime/model.py | Complete eager PyTorch implementation of Voxtral Realtime with causal Whisper encoder, Mistral decoder, and memory-efficient checkpoint loading |
| examples/models/voxtral_realtime/model.md | Detailed architecture documentation including design choices, ExecuTorch patterns, and checkpoint format |
| examples/models/voxtral_realtime/export_voxtral_rt.py | Multi-method export script supporting dynamic shapes and TorchAO quantization |
| examples/models/voxtral_realtime/voxtral_realtime_runner.h | C++ runner header defining transcription interface and config |
| examples/models/voxtral_realtime/voxtral_realtime_runner.cpp | C++ implementation handling preprocessor execution, audio encoding, and autoregressive text generation |
| examples/models/voxtral_realtime/main.cpp | CLI entry point with gflags configuration and stats reporting |
| examples/models/voxtral_realtime/README.md | User-facing documentation with setup, export, build, and run instructions |
| examples/models/voxtral_realtime/CMakeLists.txt | CMake build configuration with XNNPACK, LLM runner, and tokenizer dependencies |
| examples/models/voxtral_realtime/CMakePresets.json | CMake presets for CPU build configuration |
| examples/models/parakeet/quantize.py | Refactored to re-export from shared quantization module, eliminating code duplication |
| Makefile | Added voxtral-realtime-cpu target for building the runner |


Copilot AI review requested due to automatic review settings February 12, 2026 22:32
Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 8 comments.



granularity = PerGroup(qlinear_group_size)

if qlinear_config == "4w":
    if qlinear_packing_format:
Copilot AI Feb 12, 2026

When qlinear_packing_format is provided, Int4WeightOnlyConfig is constructed with group_size=qlinear_group_size even if qlinear_group_size == 0 (per-axis mode). This likely creates an invalid config; consider rejecting packing_format when group_size==0, or mapping per-axis to a supported group size explicitly.

Suggested change:

-    if qlinear_packing_format:
+    if qlinear_packing_format:
+        if qlinear_group_size == 0:
+            raise ValueError(
+                "qlinear_packing_format is not supported when qlinear_group_size == 0 "
+                "(per-axis quantization). Please specify a positive group size or "
+                "omit qlinear_packing_format."
+            )
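
As an aside, a minimal sketch of the second option mentioned above (mapping per-axis explicitly); this is an assumption for illustration, not code from the PR:

```python
# Hedged sketch: treat group_size == 0 as per-axis (one scale per output channel).
from torchao.quantization.granularity import PerAxis, PerGroup

def weight_granularity(group_size: int):
    return PerAxis(axis=0) if group_size == 0 else PerGroup(group_size)
```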


#include <cstring>
#include <ctime>
#include <vector>
Copilot AI Feb 12, 2026

std::min is used later in this file, but <algorithm> isn’t included here. Please include <algorithm> explicitly to avoid relying on transitive includes that can break the build on some toolchains.

Suggested change:

-#include <vector>
+#include <vector>
+#include <algorithm>

@@ -0,0 +1,296 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

There is a PR adding this model to transformers (still open): huggingface/transformers#43769

Do we plan to move this to optimum-executorch once that PR is landed?

Contributor Author

@manuelcandales

Yeah, when I first looked a few days ago, the model wasn't in transformers yet. FWIW, vLLM has its own copy of the implementation in its repo, similar to what I'm doing here. So implementing it directly seemed the most straightforward approach.

> Do we plan to move this to optimum-executorch once that PR is landed?

Maybe. Once it lands in transformers, we might. There are a few variables, such as upgrading the transformers pin in ET -- they recently had a major 5.0 update, so I assume there will be a few breakages that need fixing. Also, I'm fine keeping it as it is, if it already works.


(Voxtral author here) BTW, the transformers implementation only supports "offline" streaming for now, in which the whole audio file is encoded in one go. The architecture and forward-pass logic are definitely still the same, but I think what we're really interested in is the "true" online/realtime use case, which we implemented via the realtime API inside vLLM (see: https://docs.vllm.ai/en/latest/examples/online_serving/openai_realtime_client/?h=realtime#openai-realtime-client)

Copilot AI review requested due to automatic review settings February 13, 2026 02:01
Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.



Comment on lines +20 to +21


Copilot AI Feb 13, 2026

Import of '_custom_ops' is not used.

Suggested change:

+_ = _custom_ops  # Ensure custom ops module is imported for side effects.

Copilot AI review requested due to automatic review settings February 13, 2026 02:45
Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 4 comments.



Comment on lines +148 to +151
  uint64_t prev_token = bos_id_;
  int num_generated = 0;
  const int64_t max_pos = std::min(
      static_cast<int64_t>(config.max_new_tokens) + t_audio, max_seq_len_);
Copilot AI Feb 13, 2026

max_new_tokens is documented/flagged as a token-generation cap, but the loop bound adds t_audio, allowing up to t_audio + max_new_tokens decoding steps (and num_generated increments every step). Consider enforcing the cap based on num_generated (or renaming the field to reflect 'extra positions after audio') so CLI/docs match actual behavior.

Copilot AI review requested due to automatic review settings February 13, 2026 12:22
Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.



Comment on lines +11 to +15
#include <cstring>
#include <ctime>
#include <vector>

#include <executorch/extension/llm/runner/llm_runner_helper.h>
Copilot AI Feb 13, 2026

std::min is used below, but <algorithm> isn’t included. This can cause a compile error depending on the standard library implementation; add #include <algorithm> explicitly.

Comment on lines +104 to +109
  return from_blob(
      mel_ref.mutable_data_ptr<float>(),
      {static_cast<int>(mel_ref.size(0)),
       static_cast<int>(mel_ref.size(1)),
       static_cast<int>(mel_ref.size(2))},
      ::executorch::aten::ScalarType::Float);
Copilot AI Feb 13, 2026

These tensor size casts to int can truncate large dimensions (e.g., long mel sequences). Prefer passing mel_ref.size(n) as int64_t/SizesType without narrowing casts.

  // e. Decode token to text and emit via callback.
  auto piece =
      tokenizer_->decode(prev_token, static_cast<uint64_t>(next_token));
  if (piece.ok()) {
Copilot AI Feb 13, 2026

token_cb is invoked unconditionally when piece.ok(). If the caller passes an empty std::function, this will throw/bad_function_call. Consider either requiring a non-empty callback (check and ET_CHECK_MSG(token_cb)), or making the callback optional and guarding before calling it.

Suggested change:

-  if (piece.ok()) {
+  if (piece.ok() && token_cb) {

  bool first_token = true;

  int num_generated = runner.transcribe(
      audio_data.data(),


any chance that there is a way to feed in audio data iteratively via some kind of generator / iterator?

Contributor Author

Yep, follow-up PR coming soon.

Contributor Author

Here's the true streaming mode:

#17440

The `t_cond` is a sinusoidal embedding of `n_delay_tokens` (default 6 = 480ms),
precomputed once and passed to each decoder layer as a constant.

### Differences from original Voxtral (non-realtime)


nice summary
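
The quoted note above says `t_cond` is a precomputed sinusoidal embedding of `n_delay_tokens`. Purely for illustration, a standard transformer-style sin/cos embedding of a scalar could be computed as below; the model's actual `t_cond` formula and embedding dimension may differ:

```python
# Hedged sketch: standard sin/cos embedding of a scalar; dim is a placeholder.
import math
import torch

def sinusoidal_embedding(value: float, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = value * freqs
    return torch.cat([torch.cos(args), torch.sin(args)])  # shape: (dim,)

# n_delay_tokens = 6 (480 ms, i.e. 80 ms per token), embedded once and reused by every layer.
t_cond = sinusoidal_embedding(6.0, dim=3072)
```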

}

int VoxtralRealtimeRunner::transcribe(
    const float* audio_data,


any chance to also provide a ::realtime interface?

mergennachin merged commit c3e60d0 into main on Feb 13, 2026
354 of 357 checks passed
mergennachin deleted the voxtral_realtime branch on February 13, 2026, 15:36