feat: swap Huggingface Transformers back-end to mlx native whisper.

This commit is contained in:
syntaxbullet
2026-02-18 15:43:13 +01:00
parent e429adca48
commit 85230a14a8
4 changed files with 29 additions and 59 deletions

View File

@@ -6,9 +6,6 @@ import threading
import time
from typing import Any
# Run offline — models are downloaded during setup, no need to hit HuggingFace on every launch.
os.environ.setdefault("HF_HUB_OFFLINE", "1")
import subprocess
import rumps
@@ -36,7 +33,7 @@ class CalliopeApp(rumps.App):
self.overlay = WaveformOverlay()
self.recorder = Recorder(device=cfg.get("device"))
self.transcriber = Transcriber(
model=cfg.get("model", "distil-whisper/distil-large-v3"),
model=cfg.get("model", "mlx-community/whisper-large-v3-turbo"),
silence_threshold=cfg.get("silence_threshold", 0.005),
)
self.transcriber.context = cfg.get("context", "")
@@ -281,15 +278,11 @@ class CalliopeApp(rumps.App):
log.info("Typing mode set to %s", mode)
def _release_transcriber(self) -> None:
"""Free the current Whisper model to reclaim GPU memory."""
"""Free the current Whisper model to reclaim memory."""
import gc
import torch
if self.transcriber is not None:
self.transcriber._pipe = None
self.transcriber._tokenizer = None
self.transcriber._loaded = False
gc.collect()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
def _on_toggle_click(self, sender) -> None:
self._toggle_recording()

View File

@@ -13,7 +13,7 @@ CONFIG_PATH = CONFIG_DIR / "config.yaml"
DEFAULTS: dict[str, Any] = {
"device": None, # sounddevice index; None = system default
"model": "distil-whisper/distil-large-v3",
"model": "mlx-community/whisper-large-v3-turbo",
"language": "auto",
"hotkeys": {
"ptt": "ctrl+shift",
@@ -50,11 +50,11 @@ LANGUAGES: dict[str, str] = {
}
MODELS: list[str] = [
"distil-whisper/distil-large-v3",
"openai/whisper-large-v3",
"openai/whisper-base",
"openai/whisper-small",
"openai/whisper-medium",
"mlx-community/whisper-large-v3-turbo",
"mlx-community/whisper-large-v3",
"mlx-community/whisper-small",
"mlx-community/whisper-medium",
"mlx-community/whisper-base",
]

View File

@@ -1,21 +1,19 @@
"""Whisper transcription using transformers pipeline on MPS."""
"""Whisper transcription using mlx-whisper on Apple Silicon."""
import logging
import numpy as np
import torch
from transformers import pipeline
log = logging.getLogger(__name__)
DEFAULT_MODEL = "mlx-community/whisper-large-v3-turbo"
class Transcriber:
def __init__(self, model: str = "distil-whisper/distil-large-v3", silence_threshold: float = 0.005):
def __init__(self, model: str = DEFAULT_MODEL, silence_threshold: float = 0.005):
self.model = model
self._pipe = None
self._tokenizer = None
self._loaded = False
self._context: str = ""
self._cached_prompt_ids = None
self.language: str = "auto"
self.silence_threshold = silence_threshold
@@ -26,35 +24,24 @@ class Transcriber:
@context.setter
def context(self, value: str) -> None:
self._context = value
self._cached_prompt_ids = None # invalidate cache
def load(self) -> None:
from transformers import AutoTokenizer
import mlx_whisper
device = "mps" if torch.backends.mps.is_available() else "cpu"
# Use float32 on MPS — float16 produces garbled output on Apple Silicon.
dtype = torch.float32 if device == "mps" else torch.float16
log.info("Loading model %s on %s (dtype=%s)", self.model, device, dtype)
log.info("Loading model %s via mlx-whisper", self.model)
try:
self._pipe = pipeline(
"automatic-speech-recognition",
model=self.model,
torch_dtype=dtype,
device=device,
)
self._tokenizer = AutoTokenizer.from_pretrained(self.model)
log.info("Model loaded, running warmup...")
self._pipe(
{"raw": np.zeros(16_000, dtype=np.float32), "sampling_rate": 16_000},
batch_size=1,
mlx_whisper.transcribe(
np.zeros(16_000, dtype=np.float32),
path_or_hf_repo=self.model,
)
self._loaded = True
log.info("Model ready")
except Exception:
log.error("Failed to load model %s", self.model, exc_info=True)
raise
def transcribe(self, audio: np.ndarray) -> str:
if self._pipe is None:
if not self._loaded:
self.load()
if audio.size == 0:
return ""
@@ -68,23 +55,15 @@ class Transcriber:
log.debug("Audio too short or too quiet, skipping transcription")
return ""
generate_kwargs = {}
if self._context:
if self._cached_prompt_ids is None:
device = self._pipe.model.device
self._cached_prompt_ids = torch.tensor(self._tokenizer.get_prompt_ids(self._context), device=device)
generate_kwargs["prompt_ids"] = self._cached_prompt_ids
import mlx_whisper
pipe_kwargs = {
"batch_size": 1,
"generate_kwargs": generate_kwargs,
kwargs = {
"path_or_hf_repo": self.model,
}
if self._context:
kwargs["initial_prompt"] = self._context
if self.language != "auto":
pipe_kwargs["generate_kwargs"]["language"] = self.language
pipe_kwargs["generate_kwargs"]["task"] = "transcribe"
kwargs["language"] = self.language
result = self._pipe(
{"raw": audio, "sampling_rate": 16_000},
**pipe_kwargs,
)
result = mlx_whisper.transcribe(audio, **kwargs)
return result["text"].strip()

View File

@@ -11,9 +11,7 @@ dependencies = [
"rumps>=0.4.0",
"sounddevice>=0.4.6",
"numpy>=1.24.0",
"torch>=2.0.0",
"transformers>=4.36.0",
"accelerate>=0.25.0",
"mlx-whisper>=0.4.0",
"pynput>=1.7.6",
"pyobjc-framework-Quartz>=9.0",
"pyobjc-framework-Cocoa>=9.0",