diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 331993eff47..a410e638bc5 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -210,7 +210,7 @@ if [ "$MODEL_NAME" = "parakeet" ]; then exit 0 fi -# Voxtral Realtime uses a custom export script +# Voxtral Realtime uses a custom export script (streaming mode) if [ "$MODEL_NAME" = "voxtral_realtime" ]; then pip install safetensors huggingface_hub @@ -218,22 +218,23 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then LOCAL_MODEL_DIR="${OUTPUT_DIR}/model_weights" python -c "from huggingface_hub import snapshot_download; snapshot_download('${HF_MODEL}', local_dir='${LOCAL_MODEL_DIR}')" - # Voxtral Realtime has its own quantization flags (no --qlinear_encoder) + # Per-component quantization flags VR_QUANT_ARGS="" if [ "$QUANT_NAME" = "quantized-8da4w" ]; then - VR_QUANT_ARGS="--qlinear 8da4w --qlinear-group-size 32 --qembedding 8w" + VR_QUANT_ARGS="--qlinear-encoder 8da4w --qlinear 8da4w --qlinear-group-size 32 --qembedding 8w" fi python -m executorch.examples.models.voxtral_realtime.export_voxtral_rt \ --model-path "$LOCAL_MODEL_DIR" \ --backend xnnpack \ + --streaming \ --output-dir "${OUTPUT_DIR}" \ ${VR_QUANT_ARGS} - # Export preprocessor + # Export streaming preprocessor (no chunk padding) python -m executorch.extension.audio.mel_spectrogram \ --feature_size 128 \ - --max_audio_len 300 \ + --streaming \ --output_file "${OUTPUT_DIR}/preprocessor.pte" test -f "${OUTPUT_DIR}/model.pte" diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 61d135ccf0c..1ba7924bd5a 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -236,7 +236,7 @@ case "$MODEL_NAME" in fi ;; voxtral_realtime) - RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0" + RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0 --streaming" ;; esac diff --git a/examples/models/voxtral_realtime/README.md b/examples/models/voxtral_realtime/README.md index 073ca529216..f0bcc46487d 100644 --- a/examples/models/voxtral_realtime/README.md +++ b/examples/models/voxtral_realtime/README.md @@ -31,6 +31,16 @@ This produces `preprocessor.pte` which takes a 1-D waveform tensor `(num_samples,)` and outputs a mel spectrogram `(1, 128, T_mel)`. The `--max_audio_len 300` flag supports audio up to 5 minutes. +For streaming, add `--streaming` to skip the 30-second chunk padding so +that 1280 samples (80ms) produces exactly 8 mel frames: + +```bash +python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 128 \ + --streaming \ + --output_file ./voxtral_rt_exports/preprocessor.pte +``` + ## Export Export produces a single `.pte` file with three methods: @@ -46,12 +56,25 @@ python export_voxtral_rt.py \ --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ --backend xnnpack \ --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 8da4w \ --qlinear 8da4w \ --qembedding 8w ``` -This exports with XNNPACK backend acceleration and 8-bit dynamic activation / -4-bit weight linear quantization + 8-bit embedding quantization. +For streaming, add `--streaming` to export the encoder with KV caches for +incremental processing. 
This replaces `audio_encoder` with +`encode_audio_chunk` which processes 8 mel frames at a time: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend xnnpack \ + --streaming \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 8da4w \ + --qlinear 8da4w \ + --qembedding 8w +``` ### Options @@ -62,9 +85,13 @@ This exports with XNNPACK backend acceleration and 8-bit dynamic activation / | `--output-dir` | `./voxtral_rt_exports` | Output directory | | `--max-seq-len` | `4096` | KV cache length | | `--delay-tokens` | `6` | Transcription delay in tokens (6 = 480ms) | -| `--qlinear` | (none) | Linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`) | -| `--qlinear-group-size` | `32` | Group size for linear quantization | +| `--qlinear` | (none) | Decoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`) | +| `--qlinear-group-size` | `32` | Group size for decoder linear quantization | +| `--qlinear-encoder` | (none) | Encoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`) | +| `--qlinear-encoder-group-size` | `32` | Group size for encoder linear quantization | | `--qembedding` | (none) | Embedding layer quantization (`8w`) | +| `--streaming` | off | Export streaming encoder with KV cache | +| `--max-enc-len` | `750` | Max encoder KV cache length (~15s audio, streaming only) | ## Build @@ -91,6 +118,21 @@ cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \ --audio_path input.wav ``` +For streaming, add `--streaming`. The runner feeds audio in 200ms chunks +(simulating live microphone input), computing mel and running the +encoder+decoder per 80ms step. The `StreamingSession` C++ API +(`feed_audio` / `flush`) can be used directly for integration with live +audio sources. 
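+
+A minimal sketch of feeding a session from your own audio source (the
+`read_samples()` helper and the surrounding loop are hypothetical; only
+`create_streaming_session`, `feed_audio`, and `flush` are part of the
+runner API):
+
+```cpp
+#include <cstdio>
+#include <string>
+#include <vector>
+#include "voxtral_realtime_runner.h"  // adjust include path to your build
+
+// Sketch only: read_samples() stands in for whatever capture API you use.
+void transcribe_live(voxtral_realtime::VoxtralRealtimeRunner& runner) {
+  voxtral_realtime::TranscribeConfig config;
+  auto session = runner.create_streaming_session(
+      config, [](const std::string& piece) { printf("%s", piece.c_str()); });
+
+  std::vector<float> chunk(1280);  // one 80ms step at 16kHz
+  while (read_samples(chunk)) {    // hypothetical: fill `chunk` from the mic
+    session->feed_audio(chunk.data(), static_cast<int64_t>(chunk.size()));
+  }
+  session->flush();  // pad the final partial step and finish decoding
+}
+```
+
+The equivalent CLI invocation on a WAV file: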
+ +```bash +cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \ + --model_path voxtral_rt_exports/model.pte \ + --tokenizer_path ~/models/Voxtral-Mini-4B-Realtime-2602/tekken.json \ + --preprocessor_path voxtral_rt_exports/preprocessor.pte \ + --audio_path input.wav \ + --streaming +``` + | Flag | Default | Description | |------|---------|-------------| | `--model_path` | `model.pte` | Path to exported model | @@ -99,6 +141,7 @@ cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \ | `--audio_path` | (required) | Path to 16kHz mono WAV file | | `--temperature` | `0.0` | Sampling temperature (0 = greedy) | | `--max_new_tokens` | `500` | Maximum tokens to generate | +| `--streaming` | off | Use streaming transcription | ### Example output diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index d49fd2cf57f..02587cf1d0d 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -11,8 +11,14 @@ - text_decoder: embeds (1, seq_len, 3072) + cache_position -> logits - token_embedding: token_ids (1, seq_len) -> embeds (1, seq_len, 3072) +With --streaming, produces a streaming .pte instead: + - encode_audio_chunk: mel_chunk (1,128,8) + conv states + enc_pos -> audio_embeds + new states + - text_decoder: same as above + - token_embedding: same as above + Usage: python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 + python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --streaming """ import argparse @@ -85,37 +91,27 @@ def forward(self, token_ids: torch.Tensor) -> torch.Tensor: # --------------------------------------------------------------------------- -def export_all(model, max_seq_len): - """Export all three model components.""" - programs = {} +def _export_decoder_and_embedding( + programs, model, max_seq_len, qlinear, qlinear_group_size, qembedding +): + """Export text_decoder and token_embedding into programs dict.""" + from executorch.extension.llm.export.quantize import quantize_model_ - # Infer dtype from model weights for sample inputs param_dtype = torch.float32 - # 1. Audio encoder - print("\nExporting audio_encoder...") - audio_encoder = AudioEncoderExport(model) - audio_encoder.eval() - - # T_mel must be a multiple of 8 (conv stride 2 + downsample 4) - _t_mel_base = Dim("_t_mel_base", min=1, max=3000) - t_mel_dim = 8 * _t_mel_base - sample_mel = torch.randn(1, model.config.num_mel_bins, 160, dtype=param_dtype) - programs["audio_encoder"] = export( - audio_encoder, - (sample_mel,), - dynamic_shapes={"mel": {2: t_mel_dim}}, - strict=False, - ) - print(f" audio_encoder exported (sample input: {sample_mel.shape})") - - # 2. Text decoder print("\nExporting text_decoder...") text_decoder = TextDecoderExport(model) text_decoder.eval() + if qlinear: + print(f" Quantizing decoder ({qlinear})...") + quantize_model_( + text_decoder, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + ) + seq_dim = Dim("seq_len", min=1, max=max_seq_len) - # Use seq_len > 1 to avoid torch.export specializing on constant 1 sample_embeds = torch.randn(1, 4, model.config.dim, dtype=param_dtype) sample_pos = torch.arange(4, dtype=torch.long) programs["text_decoder"] = export( @@ -129,11 +125,17 @@ def export_all(model, max_seq_len): ) print(f" text_decoder exported (sample input: {sample_embeds.shape})") - # 3. 
Token embedding print("\nExporting token_embedding...") tok_emb = TokenEmbeddingExport(model) tok_emb.eval() + if qembedding: + print(f" Quantizing embedding ({qembedding})...") + quantize_model_( + tok_emb, + qembedding_config=qembedding, + ) + tok_seq_dim = Dim("tok_seq_len", min=1, max=max_seq_len) sample_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) programs["token_embedding"] = export( @@ -144,7 +146,51 @@ def export_all(model, max_seq_len): ) print(f" token_embedding exported (sample input: {sample_ids.shape})") - # Metadata + +def export_all( + model, + max_seq_len, + qlinear_encoder=None, + qlinear_encoder_group_size=32, + qlinear=None, + qlinear_group_size=32, + qembedding=None, +): + """Export all three model components with per-component quantization.""" + from executorch.extension.llm.export.quantize import quantize_model_ + + programs = {} + param_dtype = torch.float32 + + # 1. Audio encoder + print("\nExporting audio_encoder...") + audio_encoder = AudioEncoderExport(model) + audio_encoder.eval() + + if qlinear_encoder: + print(f" Quantizing encoder ({qlinear_encoder})...") + quantize_model_( + audio_encoder, + qlinear_config=qlinear_encoder, + qlinear_group_size=qlinear_encoder_group_size, + ) + + _t_mel_base = Dim("_t_mel_base", min=1, max=3000) + t_mel_dim = 8 * _t_mel_base + sample_mel = torch.randn(1, model.config.num_mel_bins, 160, dtype=param_dtype) + programs["audio_encoder"] = export( + audio_encoder, + (sample_mel,), + dynamic_shapes={"mel": {2: t_mel_dim}}, + strict=False, + ) + print(f" audio_encoder exported (sample input: {sample_mel.shape})") + + # 2-3. Text decoder + token embedding + _export_decoder_and_embedding( + programs, model, max_seq_len, qlinear, qlinear_group_size, qembedding + ) + metadata = { "sample_rate": 16000, "num_mel_bins": model.config.num_mel_bins, @@ -159,6 +205,100 @@ def export_all(model, max_seq_len): return programs, metadata +def export_streaming( + model, + max_seq_len, + max_enc_len=750, + qlinear_encoder=None, + qlinear_encoder_group_size=32, + qlinear=None, + qlinear_group_size=32, + qembedding=None, +): + """Export streaming model components with per-component quantization.""" + from executorch.extension.llm.export.quantize import quantize_model_ + + programs = {} + param_dtype = torch.float32 + + # 1. Streaming audio encoder + print("\nExporting encode_audio_chunk...") + from executorch.examples.models.voxtral_realtime.model import ( + StreamingAudioEncoderExport, + ) + + streaming_enc = StreamingAudioEncoderExport(model, max_enc_len=max_enc_len) + streaming_enc.eval() + + if qlinear_encoder: + print(f" Quantizing encoder ({qlinear_encoder})...") + quantize_model_( + streaming_enc, + qlinear_config=qlinear_encoder, + qlinear_group_size=qlinear_encoder_group_size, + ) + + sample_mel_chunk = torch.randn(1, model.config.num_mel_bins, 8, dtype=param_dtype) + sample_conv1_state = torch.zeros(1, model.config.num_mel_bins, 2, dtype=param_dtype) + sample_conv2_state = torch.zeros(1, model.config.enc_dim, 2, dtype=param_dtype) + sample_enc_pos = torch.arange(4, dtype=torch.long) + + programs["encode_audio_chunk"] = export( + streaming_enc, + (sample_mel_chunk, sample_conv1_state, sample_conv2_state, sample_enc_pos), + dynamic_shapes=None, + strict=False, + ) + print( + f" encode_audio_chunk exported (fixed shapes: mel_chunk={sample_mel_chunk.shape})" + ) + + # 2-3. 
Text decoder + token embedding + _export_decoder_and_embedding( + programs, model, max_seq_len, qlinear, qlinear_group_size, qembedding + ) + + # Derive STFT overlap from audio parameters. + # Left overlap: next multiple of hop_length >= n_fft/2 + # Right look-ahead: how far the last mel frame extends past the step end + # mel_skip: number of overlap frames to skip at the start + hop_length = 160 + n_fft = 400 + sample_rate = 16000 + frame_rate = 12.5 + step_samples = int(sample_rate / frame_rate) + stft_left_overlap = ((n_fft // 2 + hop_length - 1) // hop_length) * hop_length + mel_skip_frames = stft_left_overlap // hop_length + chunk_mel_len = 8 + stft_right_lookahead = ( + (chunk_mel_len - 1) * hop_length + n_fft // 2 - chunk_mel_len * hop_length + ) + # = (8-1)*160 + 200 - 8*160 = 1320 - 1280 = 40 samples = 2.5ms + + metadata = { + "sample_rate": sample_rate, + "num_mel_bins": model.config.num_mel_bins, + "hop_length": hop_length, + "window_size": n_fft, + "downsample_factor": model.config.downsample_factor, + "dim": model.config.dim, + "enc_dim": model.config.enc_dim, + "vocab_size": model.config.vocab_size, + "max_seq_len": max_seq_len, + "streaming": 1, + "step_samples": step_samples, + "chunk_mel_len": chunk_mel_len, + "max_enc_len": max_enc_len, + "conv1_pad": 2, + "conv2_pad": 2, + "stft_left_overlap": stft_left_overlap, + "stft_right_lookahead": stft_right_lookahead, + "mel_skip_frames": mel_skip_frames, + } + + return programs, metadata + + def lower_to_executorch(programs, metadata, backend="xnnpack"): """Lower exported programs to ExecuTorch.""" if backend == "xnnpack": @@ -236,13 +376,25 @@ def main(): "--qlinear", default=None, choices=["4w", "8w", "8da4w", "8da8w"], - help="Quantize linear layers (e.g., 8da4w for 8-bit dynamic activation, 4-bit weight).", + help="Quantize decoder linear layers.", ) parser.add_argument( "--qlinear-group-size", type=int, default=32, - help="Group size for linear quantization (default: 32).", + help="Group size for decoder linear quantization (default: 32).", + ) + parser.add_argument( + "--qlinear-encoder", + default=None, + choices=["4w", "8w", "8da4w", "8da8w"], + help="Quantize encoder linear layers (separate from decoder).", + ) + parser.add_argument( + "--qlinear-encoder-group-size", + type=int, + default=32, + help="Group size for encoder linear quantization (default: 32).", ) parser.add_argument( "--qembedding", @@ -250,6 +402,17 @@ def main(): choices=["8w"], help="Quantize embedding layers (8-bit weight-only).", ) + parser.add_argument( + "--streaming", + action="store_true", + help="Export streaming encoder (encode_audio_chunk) instead of offline encoder.", + ) + parser.add_argument( + "--max-enc-len", + type=int, + default=750, + help="Max encoder KV cache length for streaming (default: 750, ~15s audio).", + ) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) @@ -269,19 +432,21 @@ def main(): model.decoder.tok_embeddings.weight.clone() ) - from executorch.extension.llm.export.quantize import quantize_model_ - - print("\nQuantizing...") - quantize_model_( - model, - qlinear_config=args.qlinear, - qlinear_group_size=args.qlinear_group_size, - qembedding_config=args.qembedding, - ) - - # Export + # Export (quantization is applied per-component inside export functions) print("\nExporting components...") - programs, metadata = export_all(model, args.max_seq_len) + quant_args = { + "qlinear_encoder": args.qlinear_encoder, + "qlinear_encoder_group_size": args.qlinear_encoder_group_size, + "qlinear": args.qlinear, + 
"qlinear_group_size": args.qlinear_group_size, + "qembedding": args.qembedding, + } + if args.streaming: + programs, metadata = export_streaming( + model, args.max_seq_len, args.max_enc_len, **quant_args + ) + else: + programs, metadata = export_all(model, args.max_seq_len, **quant_args) # Lower et = lower_to_executorch(programs, metadata, backend=args.backend) diff --git a/examples/models/voxtral_realtime/main.cpp b/examples/models/voxtral_realtime/main.cpp index fc5a79661d6..5b25db580ba 100644 --- a/examples/models/voxtral_realtime/main.cpp +++ b/examples/models/voxtral_realtime/main.cpp @@ -8,8 +8,12 @@ // CLI entry point for the Voxtral Realtime transcriber. // -// Loads a .pte model, an optional preprocessor .pte, and a Tekken tokenizer. +// Loads a .pte model, a preprocessor .pte, and a Tekken tokenizer. // Processes a WAV file and prints transcribed text. +// +// Modes: +// Default: Offline transcription (full encoder, then decode) +// --streaming: Streaming transcription (incremental mel + encoder + decode) #include @@ -27,13 +31,11 @@ DEFINE_string( "model.pte", "Path to Voxtral Realtime model (.pte)."); DEFINE_string(tokenizer_path, "tekken.json", "Path to Tekken tokenizer file."); -DEFINE_string( - preprocessor_path, - "", - "Path to mel preprocessor (.pte). Required for WAV input."); +DEFINE_string(preprocessor_path, "", "Path to mel preprocessor (.pte)."); DEFINE_string(audio_path, "", "Path to input audio file (.wav)."); DEFINE_double(temperature, 0.0, "Sampling temperature (0 = greedy)."); DEFINE_int32(max_new_tokens, 500, "Maximum number of tokens to generate."); +DEFINE_bool(streaming, false, "Use streaming transcription mode."); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -68,19 +70,39 @@ int main(int argc, char** argv) { stats.num_prompt_tokens = 0; bool first_token = true; - int num_generated = runner.transcribe( - audio_data.data(), - static_cast(audio_data.size()), - config, - [&](const std::string& piece) { - if (first_token) { - stats.first_token_ms = ::executorch::extension::llm::time_in_ms(); - stats.prompt_eval_end_ms = stats.first_token_ms; - first_token = false; - } - ::executorch::extension::llm::safe_printf(piece.c_str()); - fflush(stdout); - }); + auto token_cb = [&](const std::string& piece) { + if (first_token) { + stats.first_token_ms = ::executorch::extension::llm::time_in_ms(); + stats.prompt_eval_end_ms = stats.first_token_ms; + first_token = false; + } + ::executorch::extension::llm::safe_printf(piece.c_str()); + fflush(stdout); + }; + + int num_generated; + if (FLAGS_streaming) { + ET_CHECK_MSG( + runner.is_streaming(), + "Model was not exported with --streaming. Re-export with --streaming flag."); + auto session = runner.create_streaming_session(config, token_cb); + + // Feed audio in 80ms chunks (one streaming step = 1280 samples at 16kHz). 
+ const int64_t chunk_size = 1280; + for (int64_t offset = 0; offset < static_cast(audio_data.size()); + offset += chunk_size) { + int64_t n = std::min( + chunk_size, static_cast(audio_data.size()) - offset); + session->feed_audio(audio_data.data() + offset, n); + } + num_generated = session->flush(); + } else { + num_generated = runner.transcribe( + audio_data.data(), + static_cast(audio_data.size()), + config, + token_cb); + } printf("\n"); diff --git a/examples/models/voxtral_realtime/model.md b/examples/models/voxtral_realtime/model.md index 9cf1bd9d0b4..1de93376997 100644 --- a/examples/models/voxtral_realtime/model.md +++ b/examples/models/voxtral_realtime/model.md @@ -146,9 +146,92 @@ The encoder has no KV cache (processes full mel at once in offline mode) and no GQA (n_heads == n_kv_heads). Uses `F.scaled_dot_product_attention` with `is_causal=True` and standard `[B, H, T, D]` layout. No custom ops needed. -Uses full causal attention (no sliding window of 750) — acceptable for -offline mode and simpler for export. Sliding window would be added for -streaming (Phase 3). +Uses full causal attention (no sliding window of 750). The model's +`params.json` specifies `sliding_window: 750` but this is not enforced +in the ExecuTorch implementation — the KV cache (streaming) or full +attention (offline) provides equivalent or broader context. + +## Streaming Encoder (`StreamingAudioEncoderExport`) + +For streaming/live transcription, the encoder processes audio incrementally +(8 mel frames = 80ms per step) instead of the full mel at once. + +### Architecture + +`StreamingAudioEncoderExport` shares all weights with the offline encoder +but uses a different forward path: + +``` +mel_chunk (1, 128, 8) + + conv1_state (1, 128, 2) + conv2_state (1, 1280, 2) + -> cat(state, chunk) -> raw Conv1d (no CausalConv1d padding) -> GELU + -> cat(state, conv1_out) -> raw Conv1d -> GELU +(1, 1280, 4) -> transpose -> (1, 4, 1280) + -> 32x streaming encoder layer (KV cache + custom_sdpa) + -> RMSNorm +(1, 4, 1280) + -> Reshape downsample (1, 1, 5120) -> Adapter (1, 1, 3072) +-> audio_embeds, new_conv1_state, new_conv2_state +``` + +### Conv state management + +The causal convolutions need left context across chunk boundaries. +Instead of zero-padding (offline) or recompute-with-overlap (vLLM), +explicit conv state carries the tail of the previous chunk: + +- **Conv1** (kernel=3, stride=1): state = last 2 mel frames from previous + chunk. `cat(state, chunk)` → (1, 128, 10) → Conv1d → (1, 1280, 8). +- **Conv2** (kernel=3, stride=2): state = last 2 conv1 GELU output frames. + `cat(state, conv1_out)` → (1, 1280, 10) → Conv1d → (1, 1280, 4). + +The raw `nn.Conv1d` is called directly (bypassing `CausalConv1d.forward` +which would zero-pad). This produces identical results to the offline +encoder — verified to within fp32 precision (max diff < 2e-5). + +### Encoder KV cache + +Each of the 32 encoder transformer layers gets its own `KVCache` instance +(reusing the same class as the decoder). The `SDPA` module handles causal +attention via `start_pos`, accumulating encoder frames incrementally. + +- Cache shape: `(1, max_enc_len, 32, 64)` per layer +- Default `max_enc_len=750` (~15s audio, matching the model's trained + sliding window). Configurable via `--max-enc-len`. +- Memory: 32 layers × 2 × 750 × 32 × 64 × 4 bytes ≈ 393 MB (fp32) + +### STFT overlap for streaming mel + +The streaming preprocessor (`WhisperAudioProcessor(streaming=True)`) +computes mel without 30-second chunk padding. 
To match offline mel values +at chunk boundaries, the C++ runner uses overlapping audio windows: + +- **Left overlap**: 320 samples (2 × hop_length, ≥ n_fft/2 = 200) +- **Right look-ahead**: 40 samples (2.5ms, matches vLLM's + `streaming_look_ahead_ms`) +- **Total window**: 320 + 1280 + 40 = 1640 samples → 10 mel frames +- **Frame extraction**: skip first 2 frames (overlap region), take + frames 2–9 (the 8 that align with offline mel frame positions) + +For the first step, the left overlap is zero-padded (matching the +offline encoder's `center=True` STFT edge behavior). The 2.5ms +look-ahead introduces negligible latency. + +### Per-component quantization + +Quantization is applied per-component after wrapping (following the +Parakeet pattern), allowing different configs for encoder vs decoder: + +```bash +--qlinear-encoder 8w # encoder linear layers +--qlinear 8da4w # decoder linear layers +--qembedding 8w # embedding layer +``` + +The streaming encoder references the same module objects that +`quantize_model_()` mutates in-place, so quantized weights are +used transparently. Conv1d layers are not quantized (not targeted +by `quantize_model_`). KV caches and SDPA have no trainable weights. ## Checkpoint Format @@ -215,4 +298,14 @@ VoxtralRealtimeModel feed_forward: LMMLP (w1/w2/w3) norm: RMSNorm output: Linear (tied to tok_embeddings) + +StreamingAudioEncoderExport (export wrapper, shares weights with encoder + adapter) + conv1: nn.Conv1d (shared from encoder.conv_layers[0].conv) + conv2: nn.Conv1d (shared from encoder.conv_layers[1].conv) + layers: 32x CausalEncoderLayer (shared from encoder.layers) + enc_norm: RMSNorm (shared from encoder.norm) + adapter: AudioLanguageAdapter (shared from model.adapter) + kv_caches: 32x KVCache (owned, for streaming attention) + sdpa: SDPA (owned, for streaming attention) + freqs_cos/sin: RoPE buffers (owned, encoder dims) ``` diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index 3b42839b409..b775d65f0d4 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -560,6 +560,124 @@ def token_embedding(self, token_ids: torch.Tensor) -> torch.Tensor: return self.decoder.tok_embeddings(token_ids) +# --------------------------------------------------------------------------- +# Streaming encoder +# --------------------------------------------------------------------------- + + +class StreamingAudioEncoderExport(nn.Module): + """Streaming encoder: processes one 8-mel-frame chunk at a time. + + Shares conv/transformer/adapter weights with the offline encoder. + Owns separate KV caches and SDPA for incremental KV-cached attention. 
+ + Forward: + mel_chunk(1,128,8) + conv1_state(1,128,2) + conv2_state(1,1280,2) + + enc_input_pos(4,) + -> audio_embeds(1,1,3072), new_conv1_state(1,128,2), new_conv2_state(1,1280,2) + """ + + def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): + super().__init__() + config = model.config + + # Shared encoder weights (read-only references, never mutated) + self.conv1 = model.encoder.conv_layers[0].conv + self.conv2 = model.encoder.conv_layers[1].conv + self.layers = model.encoder.layers + self.enc_norm = model.encoder.norm + self.adapter = model.adapter + + self.downsample_factor = config.downsample_factor + self.n_heads = config.enc_n_heads + self.head_dim = config.enc_head_dim + + # Streaming-specific: encoder KV caches (one per layer) + self.kv_caches = nn.ModuleList( + [ + KVCache(max_enc_len, config.enc_n_heads, config.enc_head_dim) + for _ in range(config.enc_n_layers) + ] + ) + + # SDPA for encoder MHA (n_heads=32, head_dim=64 -> attn_dim=2048) + self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim) + + # RoPE for encoder dimensions + freqs_cos, freqs_sin = precompute_freqs_cis( + config.enc_head_dim, max_enc_len, config.enc_rope_theta + ) + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def _streaming_encoder_layer( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: torch.Tensor, + layer: CausalEncoderLayer, + layer_idx: int, + ) -> torch.Tensor: + """One encoder layer with streaming attention (KV cache + custom_sdpa).""" + h = layer.attention_norm(x) + + B, T, _ = h.shape + attn = layer.attention + q = attn.wq(h).view(B, T, self.n_heads, self.head_dim) + k = attn.wk(h).view(B, T, self.n_heads, self.head_dim) + v = attn.wv(h).view(B, T, self.n_heads, self.head_dim) + + q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin) + k, v = self.kv_caches[layer_idx].update(input_pos, k, v) + y = self.sdpa(input_pos, q, k, v, B, T) + y = attn.wo(y) + + x = x + y + x = x + layer.feed_forward(layer.ffn_norm(x)) + return x + + def forward( + self, + mel_chunk: torch.Tensor, + conv1_state: torch.Tensor, + conv2_state: torch.Tensor, + enc_input_pos: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Conv1: cat state + chunk, raw Conv1d (no CausalConv1d padding) + # (1, 128, 2+8=10) -> conv1(k=3, s=1) -> (1, 1280, 8) + conv1_input = torch.cat([conv1_state, mel_chunk], dim=2) + conv1_out = F.gelu(self.conv1(conv1_input)) + new_conv1_state = mel_chunk[:, :, -2:] + + # Conv2: cat state + conv1_out, raw Conv1d + # (1, 1280, 2+8=10) -> conv2(k=3, s=2) -> (1, 1280, 4) + conv2_input = torch.cat([conv2_state, conv1_out], dim=2) + conv2_out = F.gelu(self.conv2(conv2_input)) + new_conv2_state = conv1_out[:, :, -2:] + + x = conv2_out.transpose(1, 2) # (1, 4, 1280) + + # Encoder transformer with KV cache + freqs_cos = self.freqs_cos[enc_input_pos] + freqs_sin = self.freqs_sin[enc_input_pos] + + for i, layer in enumerate(self.layers): + x = self._streaming_encoder_layer( + x, freqs_cos, freqs_sin, enc_input_pos, layer, i + ) + + x = self.enc_norm(x) # (1, 4, 1280) + + # Downsample: concat 4 consecutive frames -> (1, 1, 5120) + B, T, D = x.shape + x = x.reshape(B, T // self.downsample_factor, D * self.downsample_factor) + + audio_embeds = self.adapter(x) # (1, 1, 3072) + + return audio_embeds, new_conv1_state, new_conv2_state + + # --------------------------------------------------------------------------- # Weight loading # 
--------------------------------------------------------------------------- diff --git a/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp b/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp index f5b2a88bb81..226f397bd50 100644 --- a/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp +++ b/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp @@ -60,6 +60,52 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner( static_cast(vocab_size_), static_cast(dim_)); + // Detect streaming model (exported with --streaming flag). + auto streaming_val = model_->execute("streaming", empty); + if (streaming_val.ok() && streaming_val.get()[0].toInt() == 1) { + is_streaming_ = true; + + auto nmb = model_->execute("num_mel_bins", empty); + if (nmb.ok()) + num_mel_bins_ = nmb.get()[0].toInt(); + auto cm = model_->execute("chunk_mel_len", empty); + if (cm.ok()) + chunk_mel_len_ = cm.get()[0].toInt(); + auto me = model_->execute("max_enc_len", empty); + if (me.ok()) + max_enc_len_ = me.get()[0].toInt(); + auto ed = model_->execute("enc_dim", empty); + if (ed.ok()) + enc_dim_ = ed.get()[0].toInt(); + auto c1 = model_->execute("conv1_pad", empty); + if (c1.ok()) + conv1_pad_ = c1.get()[0].toInt(); + auto c2 = model_->execute("conv2_pad", empty); + if (c2.ok()) + conv2_pad_ = c2.get()[0].toInt(); + + auto ss = model_->execute("step_samples", empty); + if (ss.ok()) + step_samples_ = ss.get()[0].toInt(); + + auto slo = model_->execute("stft_left_overlap", empty); + if (slo.ok()) + stft_left_overlap_ = slo.get()[0].toInt(); + auto srl = model_->execute("stft_right_lookahead", empty); + if (srl.ok()) + stft_right_lookahead_ = srl.get()[0].toInt(); + auto msf = model_->execute("mel_skip_frames", empty); + if (msf.ok()) + mel_skip_frames_ = msf.get()[0].toInt(); + + ET_LOG( + Info, + "Streaming: chunk_mel=%ld, max_enc=%ld, enc_dim=%ld", + static_cast(chunk_mel_len_), + static_cast(max_enc_len_), + static_cast(enc_dim_)); + } + // Tekken tokenizer (tekken.json) for the Mistral vocabulary. ET_LOG(Info, "Loading tokenizer from: %s", tokenizer_path.c_str()); tokenizer_ = ::executorch::extension::llm::load_tokenizer(tokenizer_path); @@ -224,4 +270,293 @@ int VoxtralRealtimeRunner::transcribe( return num_generated; } +// --------------------------------------------------------------------------- +// StreamingSession +// --------------------------------------------------------------------------- + +std::unique_ptr +VoxtralRealtimeRunner::create_streaming_session( + const TranscribeConfig& config, + TokenCallback token_cb) { + ET_CHECK_MSG(is_streaming_, "Model was not exported with --streaming."); + ET_CHECK_MSG( + preprocessor_ != nullptr, + "No preprocessor loaded. Provide --preprocessor_path."); + return std::make_unique(*this, config, std::move(token_cb)); +} + +StreamingSession::StreamingSession( + VoxtralRealtimeRunner& runner, + TranscribeConfig config, + TokenCallback token_cb) + : runner_(runner), + config_(config), + token_cb_(std::move(token_cb)), + prev_token_(runner.bos_id_), + sampler_( + static_cast(runner.vocab_size_), + config.temperature, + ::executorch::extension::llm::kTopp, + static_cast(std::time(nullptr))), + input_embeds_buf_(static_cast(runner.dim_)) { + // Initialize conv states to zero (matches offline encoder's left-padding). 
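+  // After the first chunk they are overwritten with the tail of the previous
+  // chunk (new_conv1_state / new_conv2_state), carrying conv left context
+  // across chunk boundaries.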
+ // num_mel_bins=128, conv1_pad_=2 → 128*2 = 256 floats + conv1_state_.assign( + static_cast(runner.num_mel_bins_ * runner.conv1_pad_), 0.0f); + // enc_dim_=1280, conv2_pad_=2 → 1280*2 = 2560 floats + conv2_state_.assign( + static_cast(runner.enc_dim_ * runner.conv2_pad_), 0.0f); +} + +int StreamingSession::feed_audio(const float* data, int64_t num_samples) { + audio_buf_.insert(audio_buf_.end(), data, data + num_samples); + + int new_tokens = 0; + while (!eos_reached_ && try_process_step()) { + new_tokens++; + } + + // Trim consumed audio to bound memory growth. Keep stft_left_overlap_ + // samples before samples_consumed_ for the next step's left context. + int64_t keep_from = samples_consumed_ - runner_.stft_left_overlap_; + if (keep_from > 0) { + audio_buf_.erase( + audio_buf_.begin(), + audio_buf_.begin() + static_cast(keep_from)); + samples_consumed_ -= keep_from; + } + + return new_tokens; +} + +bool StreamingSession::try_process_step() { + const int64_t step = runner_.step_samples_; + const int64_t left_overlap = runner_.stft_left_overlap_; + const int64_t right_lookahead = runner_.stft_right_lookahead_; + const int64_t mel_skip = runner_.mel_skip_frames_; + const int64_t chunk_mel_len = runner_.chunk_mel_len_; + + // Need enough audio for: current step + right look-ahead. + // Left overlap comes from audio before samples_consumed_ (already in buffer). + const int64_t need_end = samples_consumed_ + step + right_lookahead; + if (static_cast(audio_buf_.size()) < need_end) { + return false; + } + + // Guard: encoder/decoder cache capacity. + const int64_t enc_frames_per_chunk = chunk_mel_len / 2; + if (enc_frame_pos_ + enc_frames_per_chunk > runner_.max_enc_len_ || + dec_pos_ >= runner_.max_seq_len_) { + return false; + } + + // --- Build the overlapping audio window --- + // Window: [left_overlap] + [step (1280)] + [right_lookahead (40)] = 1640 + // samples For the first step (samples_consumed_=0), left side is zero-padded. 
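+  // left_overlap is the smallest multiple of hop_length >= n_fft/2 (320 by
+  // default), so the STFT windows centered on this step's frames see the same
+  // audio as in offline mode; the 40-sample look-ahead covers the tail of the
+  // last frame's window.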
+ const int64_t window_size = left_overlap + step + right_lookahead; + std::vector window_buf(static_cast(window_size), 0.0f); + + // Left overlap: copy from audio_buf_ before samples_consumed_ + int64_t left_start = samples_consumed_ - left_overlap; + if (left_start >= 0) { + std::memcpy( + window_buf.data(), + audio_buf_.data() + left_start, + static_cast(left_overlap) * sizeof(float)); + } else { + // Partial left overlap (first step): zero-pad then copy available + int64_t available_left = samples_consumed_; + int64_t zero_pad = left_overlap - available_left; + // window_buf[0..zero_pad) is already 0.0f + if (available_left > 0) { + std::memcpy( + window_buf.data() + zero_pad, + audio_buf_.data(), + static_cast(available_left) * sizeof(float)); + } + } + + // Step + right look-ahead + std::memcpy( + window_buf.data() + left_overlap, + audio_buf_.data() + samples_consumed_, + static_cast(step + right_lookahead) * sizeof(float)); + + // --- Compute mel spectrogram on the full window --- + auto audio_tensor = from_blob( + window_buf.data(), + {static_cast(window_size)}, + ::executorch::aten::ScalarType::Float); + + auto mel_result = runner_.preprocessor_->execute( + "forward", std::vector{*audio_tensor}); + ET_CHECK_MSG(mel_result.ok(), "Streaming preprocessor failed."); + + auto mel = mel_result.get()[0].toTensor(); + // mel shape: (1, 128, 10) — 10 frames from 1640 samples with center=True + const int64_t num_mel_bins = mel.size(1); + const int64_t total_mel_frames = mel.size(2); + + ET_CHECK_MSG( + total_mel_frames >= mel_skip + chunk_mel_len, + "Preprocessor produced fewer mel frames than expected."); + + // --- Extract frames [mel_skip, mel_skip+8) = frames 2-9 --- + // These align exactly with the offline mel frames for this step. + // Output layout is channels-first: (1, 128, T). For each channel, + // copy 8 contiguous frames starting at offset mel_skip. 
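+  // Source stride is total_mel_frames floats per channel; destination stride
+  // is chunk_mel_len floats per channel.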
+ std::vector mel_chunk_buf( + static_cast(num_mel_bins * chunk_mel_len)); + const float* mel_data = mel.const_data_ptr(); + for (int64_t c = 0; c < num_mel_bins; c++) { + std::memcpy( + mel_chunk_buf.data() + c * chunk_mel_len, + mel_data + c * total_mel_frames + mel_skip, + static_cast(chunk_mel_len) * sizeof(float)); + } + + auto mel_chunk = from_blob( + mel_chunk_buf.data(), + {1, static_cast(num_mel_bins), static_cast(chunk_mel_len)}, + ::executorch::aten::ScalarType::Float); + + auto conv1_state = from_blob( + conv1_state_.data(), + {1, static_cast(num_mel_bins), static_cast(runner_.conv1_pad_)}, + ::executorch::aten::ScalarType::Float); + + auto conv2_state = from_blob( + conv2_state_.data(), + {1, + static_cast(runner_.enc_dim_), + static_cast(runner_.conv2_pad_)}, + ::executorch::aten::ScalarType::Float); + + std::vector enc_pos_data(static_cast(enc_frames_per_chunk)); + for (int64_t i = 0; i < enc_frames_per_chunk; i++) { + enc_pos_data[static_cast(i)] = enc_frame_pos_ + i; + } + auto enc_pos = from_blob( + enc_pos_data.data(), + {static_cast(enc_frames_per_chunk)}, + ::executorch::aten::ScalarType::Long); + + // --- Run streaming encoder --- + auto enc_result = runner_.model_->execute( + "encode_audio_chunk", + std::vector{*mel_chunk, *conv1_state, *conv2_state, *enc_pos}); + ET_CHECK_MSG(enc_result.ok(), "encode_audio_chunk failed."); + + auto& enc_outputs = enc_result.get(); + auto audio_embeds = enc_outputs[0].toTensor(); + auto new_conv1 = enc_outputs[1].toTensor(); + auto new_conv2 = enc_outputs[2].toTensor(); + + std::memcpy( + conv1_state_.data(), + new_conv1.const_data_ptr(), + conv1_state_.size() * sizeof(float)); + std::memcpy( + conv2_state_.data(), + new_conv2.const_data_ptr(), + conv2_state_.size() * sizeof(float)); + enc_frame_pos_ += enc_frames_per_chunk; + samples_consumed_ += step; + + // --- Decode one step --- + return decode_step(audio_embeds.const_data_ptr()); +} + +bool StreamingSession::decode_step(const float* audio_embeds) { + // Token embedding for previous token. + int64_t token_id = static_cast(prev_token_); + auto token_tensor = + from_blob(&token_id, {1, 1}, ::executorch::aten::ScalarType::Long); + + auto tok_result = runner_.model_->execute( + "token_embedding", std::vector{*token_tensor}); + ET_CHECK_MSG(tok_result.ok(), "token_embedding failed."); + auto tok_embed = tok_result.get()[0].toTensor(); + const float* tok_data = tok_embed.const_data_ptr(); + + // Sum audio + token embeddings (or token-only if audio_embeds is null). 
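+  // audio_embeds is non-null for every audio step; flush() passes nullptr for
+  // the text-only decoding that continues after the audio is exhausted.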
+ if (audio_embeds != nullptr) { + for (int64_t i = 0; i < runner_.dim_; i++) { + input_embeds_buf_[static_cast(i)] = audio_embeds[i] + tok_data[i]; + } + } else { + std::memcpy( + input_embeds_buf_.data(), + tok_data, + static_cast(runner_.dim_) * sizeof(float)); + } + + auto input_embeds = from_blob( + input_embeds_buf_.data(), + {1, 1, static_cast(runner_.dim_)}, + ::executorch::aten::ScalarType::Float); + + auto cache_pos = + from_blob(&dec_pos_, {1}, ::executorch::aten::ScalarType::Long); + + auto dec_result = runner_.model_->execute( + "text_decoder", std::vector{*input_embeds, *cache_pos}); + ET_CHECK_MSG(dec_result.ok(), "text_decoder failed."); + + auto logits = dec_result.get()[0].toTensor(); + float* logits_data = + logits.mutable_data_ptr() + (logits.numel() - runner_.vocab_size_); + int64_t next_token = static_cast(sampler_.sample(logits_data)); + num_generated_++; + + auto piece = runner_.tokenizer_->decode( + prev_token_, static_cast(next_token)); + if (piece.ok()) { + token_cb_(*piece); + } + + if (static_cast(next_token) == runner_.eos_id_) { + eos_reached_ = true; + return true; + } + + prev_token_ = static_cast(next_token); + dec_pos_++; + return true; +} + +int StreamingSession::flush() { + if (flushed_) { + return num_generated_; + } + flushed_ = true; + + // Pad with silence so any remaining audio (including partial steps and + // the right look-ahead for the last complete step) can be processed. + const int64_t remaining = + static_cast(audio_buf_.size()) - samples_consumed_; + if (remaining > 0 && !eos_reached_) { + const int64_t step = runner_.step_samples_; + const int64_t right_lookahead = runner_.stft_right_lookahead_; + // Pad to next full step + right look-ahead + int64_t pad_to = ((remaining + step - 1) / step) * step + right_lookahead; + std::vector silence(static_cast(pad_to - remaining), 0.0f); + audio_buf_.insert(audio_buf_.end(), silence.begin(), silence.end()); + + while (!eos_reached_ && try_process_step()) { + } + } + + // Text-only decoding after audio ends. + const int64_t max_text_steps = std::min( + static_cast(config_.max_new_tokens) - num_generated_, + runner_.max_seq_len_ - dec_pos_); + + for (int64_t i = 0; i < max_text_steps && !eos_reached_; i++) { + decode_step(nullptr); + } + + return num_generated_; +} + } // namespace voxtral_realtime diff --git a/examples/models/voxtral_realtime/voxtral_realtime_runner.h b/examples/models/voxtral_realtime/voxtral_realtime_runner.h index b53ed774022..35ad1903383 100644 --- a/examples/models/voxtral_realtime/voxtral_realtime_runner.h +++ b/examples/models/voxtral_realtime/voxtral_realtime_runner.h @@ -12,7 +12,9 @@ #include #include #include +#include +#include #include #include #include @@ -30,6 +32,8 @@ struct TranscribeConfig { using TokenCallback = std::function; +class StreamingSession; + class VoxtralRealtimeRunner { public: VoxtralRealtimeRunner( @@ -37,21 +41,33 @@ class VoxtralRealtimeRunner { const std::string& tokenizer_path, const std::string& preprocessor_path = ""); - // Transcribe audio. Returns the number of generated tokens. + // Offline transcription: full encoder first, then step-by-step decode. int transcribe( const float* audio_data, int64_t num_samples, const TranscribeConfig& config, TokenCallback token_cb); + // Streaming transcription: processes raw audio incrementally via + // StreamingSession. Requires a model exported with --streaming and + // a streaming preprocessor .pte. 
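+  // Tokens are delivered through token_cb as they are generated; call
+  // feed_audio() as samples arrive and flush() once at end of stream.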
+ std::unique_ptr create_streaming_session( + const TranscribeConfig& config, + TokenCallback token_cb); + int64_t max_seq_len() const { return max_seq_len_; } int64_t vocab_size() const { return vocab_size_; } + bool is_streaming() const { + return is_streaming_; + } private: + friend class StreamingSession; + std::unique_ptr<::executorch::extension::Module> model_; std::unique_ptr<::executorch::extension::Module> preprocessor_; std::unique_ptr tokenizer_; @@ -61,6 +77,23 @@ class VoxtralRealtimeRunner { int64_t vocab_size_ = 131072; int64_t dim_ = 3072; + // Streaming metadata (from constant_methods, if present) + bool is_streaming_ = false; + int64_t num_mel_bins_ = 128; + int64_t chunk_mel_len_ = 8; + int64_t max_enc_len_ = 750; + int64_t enc_dim_ = 1280; + int64_t conv1_pad_ = 2; + int64_t conv2_pad_ = 2; + + // Raw audio samples per streaming step (sampling_rate / frame_rate = 1280) + int64_t step_samples_ = 1280; + + // STFT overlap for streaming mel computation (read from model metadata). + int64_t stft_left_overlap_ = 320; + int64_t stft_right_lookahead_ = 40; + int64_t mel_skip_frames_ = 2; + // Tokenizer special tokens uint64_t bos_id_ = 1; uint64_t eos_id_ = 2; @@ -71,4 +104,57 @@ class VoxtralRealtimeRunner { int64_t num_samples); }; +// Streaming session: accepts raw audio incrementally via feed_audio(), +// computes mel spectrogram per step, and runs encoder+decoder in real-time. +class StreamingSession { + public: + StreamingSession( + VoxtralRealtimeRunner& runner, + TranscribeConfig config, + TokenCallback token_cb); + + // Feed raw audio (16kHz float32). Processes as many complete 80ms steps + // as possible. Returns number of new tokens generated. + int feed_audio(const float* data, int64_t num_samples); + + // Signal end of audio. Pads last partial step, then generates remaining + // text-only tokens until EOS or max_new_tokens. Returns total tokens + // generated across the entire session. + int flush(); + + int total_tokens() const { + return num_generated_; + } + + private: + VoxtralRealtimeRunner& runner_; + TranscribeConfig config_; + TokenCallback token_cb_; + + // Raw audio accumulation buffer + std::vector audio_buf_; + int64_t samples_consumed_ = 0; + + // Encoder streaming state + std::vector conv1_state_; + std::vector conv2_state_; + int64_t enc_frame_pos_ = 0; + + // Decoder state + int64_t dec_pos_ = 0; + uint64_t prev_token_; + int num_generated_ = 0; + bool eos_reached_ = false; + bool flushed_ = false; + + ::executorch::extension::llm::Sampler sampler_; + std::vector input_embeds_buf_; + + // Process one 80ms step from the audio buffer. + bool try_process_step(); + + // Run one decoder step (token_embed + optional audio_embed -> logits). + bool decode_step(const float* audio_embeds); +}; + } // namespace voxtral_realtime diff --git a/extension/audio/mel_spectrogram.py b/extension/audio/mel_spectrogram.py index e02b34fc44c..50b9ded01af 100644 --- a/extension/audio/mel_spectrogram.py +++ b/extension/audio/mel_spectrogram.py @@ -51,6 +51,7 @@ def __init__( padding_value: float = 0.0, max_audio_len: int = 600, stack_output: bool = False, + streaming: bool = False, ) -> None: super().__init__() self.feature_size = feature_size @@ -68,6 +69,13 @@ def __init__( ) self.max_audio_len = max_audio_len self.stack_output = stack_output + self.streaming = streaming + + if self.streaming and self.stack_output: + raise ValueError( + "--streaming and --stack_output are mutually exclusive. " + "stack_output assumes 30-second chunk padding which streaming disables." 
+ ) def get_mel_filters( self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32 @@ -130,15 +138,18 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Output of shape [1, feature_size, nb_max_frames * n_chunks] n_chunks is the number of chunks of `sampling_rate` samples in the input waveform. - [1, 80, 3000] with default options and 1 chunk + [1, 80, 3000] with default options and 1 chunk. + In streaming mode, output shape is [1, feature_size, floor(N/hop_length)] + with no chunk padding. """ - n_chunks = (waveform.shape[0] - 1) // self.n_samples + 1 - waveform = F.pad( - waveform, - (0, self.n_samples * n_chunks - waveform.shape[0]), - mode="constant", - value=self.padding_value, - ) + if not self.streaming: + n_chunks = (waveform.shape[0] - 1) // self.n_samples + 1 + waveform = F.pad( + waveform, + (0, self.n_samples * n_chunks - waveform.shape[0]), + mode="constant", + value=self.padding_value, + ) # Ideally we should do: # window = torch.hann_window(self.n_fft) @@ -259,6 +270,11 @@ def main(): action="store_true", help="Whether to stack output along the batch dimension, one per chunk. Used by models such as Voxtral, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information.", ) + parser.add_argument( + "--streaming", + action="store_true", + help="Streaming mode: skip 30-second chunk padding, produce mel frames proportional to input length. For use with real-time audio input.", + ) args = parser.parse_args() @@ -270,6 +286,7 @@ def main(): n_fft=args.n_fft, max_audio_len=args.max_audio_len, stack_output=args.stack_output, + streaming=args.streaming, ) export_processor(model, args.output_file)