feat: swap Huggingface Transformers back-end to mlx native whisper.

This commit is contained in:
syntaxbullet
2026-02-18 15:43:13 +01:00
parent e429adca48
commit 85230a14a8
4 changed files with 29 additions and 59 deletions

View File

@@ -6,9 +6,6 @@ import threading
 import time
 from typing import Any
-
-# Run offline — models are downloaded during setup, no need to hit HuggingFace on every launch.
-os.environ.setdefault("HF_HUB_OFFLINE", "1")
 import subprocess
 import rumps
@@ -36,7 +33,7 @@ class CalliopeApp(rumps.App):
         self.overlay = WaveformOverlay()
         self.recorder = Recorder(device=cfg.get("device"))
         self.transcriber = Transcriber(
-            model=cfg.get("model", "distil-whisper/distil-large-v3"),
+            model=cfg.get("model", "mlx-community/whisper-large-v3-turbo"),
             silence_threshold=cfg.get("silence_threshold", 0.005),
         )
         self.transcriber.context = cfg.get("context", "")
@@ -281,15 +278,11 @@ class CalliopeApp(rumps.App):
         log.info("Typing mode set to %s", mode)

     def _release_transcriber(self) -> None:
-        """Free the current Whisper model to reclaim GPU memory."""
+        """Free the current Whisper model to reclaim memory."""
         import gc
-        import torch
         if self.transcriber is not None:
-            self.transcriber._pipe = None
-            self.transcriber._tokenizer = None
+            self.transcriber._loaded = False
         gc.collect()
-        if torch.backends.mps.is_available():
-            torch.mps.empty_cache()

     def _on_toggle_click(self, sender) -> None:
         self._toggle_recording()

View File

@@ -13,7 +13,7 @@ CONFIG_PATH = CONFIG_DIR / "config.yaml"
 DEFAULTS: dict[str, Any] = {
     "device": None,  # sounddevice index; None = system default
-    "model": "distil-whisper/distil-large-v3",
+    "model": "mlx-community/whisper-large-v3-turbo",
     "language": "auto",
     "hotkeys": {
         "ptt": "ctrl+shift",
@@ -50,11 +50,11 @@ LANGUAGES: dict[str, str] = {
 }

 MODELS: list[str] = [
-    "distil-whisper/distil-large-v3",
-    "openai/whisper-large-v3",
-    "openai/whisper-base",
-    "openai/whisper-small",
-    "openai/whisper-medium",
+    "mlx-community/whisper-large-v3-turbo",
+    "mlx-community/whisper-large-v3",
+    "mlx-community/whisper-small",
+    "mlx-community/whisper-medium",
+    "mlx-community/whisper-base",
 ]

View File

@@ -1,21 +1,19 @@
-"""Whisper transcription using transformers pipeline on MPS."""
+"""Whisper transcription using mlx-whisper on Apple Silicon."""
 import logging

 import numpy as np
-import torch
-from transformers import pipeline

 log = logging.getLogger(__name__)

+DEFAULT_MODEL = "mlx-community/whisper-large-v3-turbo"
+

 class Transcriber:
-    def __init__(self, model: str = "distil-whisper/distil-large-v3", silence_threshold: float = 0.005):
+    def __init__(self, model: str = DEFAULT_MODEL, silence_threshold: float = 0.005):
         self.model = model
-        self._pipe = None
-        self._tokenizer = None
+        self._loaded = False
         self._context: str = ""
-        self._cached_prompt_ids = None
         self.language: str = "auto"
         self.silence_threshold = silence_threshold
@@ -26,35 +24,24 @@ class Transcriber:
     @context.setter
     def context(self, value: str) -> None:
         self._context = value
-        self._cached_prompt_ids = None  # invalidate cache

     def load(self) -> None:
-        from transformers import AutoTokenizer
-        device = "mps" if torch.backends.mps.is_available() else "cpu"
-        # Use float32 on MPS — float16 produces garbled output on Apple Silicon.
-        dtype = torch.float32 if device == "mps" else torch.float16
-        log.info("Loading model %s on %s (dtype=%s)", self.model, device, dtype)
+        import mlx_whisper
+        log.info("Loading model %s via mlx-whisper", self.model)
         try:
-            self._pipe = pipeline(
-                "automatic-speech-recognition",
-                model=self.model,
-                torch_dtype=dtype,
-                device=device,
-            )
-            self._tokenizer = AutoTokenizer.from_pretrained(self.model)
-            log.info("Model loaded, running warmup...")
-            self._pipe(
-                {"raw": np.zeros(16_000, dtype=np.float32), "sampling_rate": 16_000},
-                batch_size=1,
-            )
+            mlx_whisper.transcribe(
+                np.zeros(16_000, dtype=np.float32),
+                path_or_hf_repo=self.model,
+            )
+            self._loaded = True
             log.info("Model ready")
         except Exception:
             log.error("Failed to load model %s", self.model, exc_info=True)
             raise

     def transcribe(self, audio: np.ndarray) -> str:
-        if self._pipe is None:
+        if not self._loaded:
             self.load()
         if audio.size == 0:
             return ""
@@ -68,23 +55,15 @@ class Transcriber:
             log.debug("Audio too short or too quiet, skipping transcription")
             return ""

-        generate_kwargs = {}
-        if self._context:
-            if self._cached_prompt_ids is None:
-                device = self._pipe.model.device
-                self._cached_prompt_ids = torch.tensor(self._tokenizer.get_prompt_ids(self._context), device=device)
-            generate_kwargs["prompt_ids"] = self._cached_prompt_ids
-        pipe_kwargs = {
-            "batch_size": 1,
-            "generate_kwargs": generate_kwargs,
-        }
+        import mlx_whisper
+        kwargs = {
+            "path_or_hf_repo": self.model,
+        }
+        if self._context:
+            kwargs["initial_prompt"] = self._context
         if self.language != "auto":
-            pipe_kwargs["generate_kwargs"]["language"] = self.language
-            pipe_kwargs["generate_kwargs"]["task"] = "transcribe"
-        result = self._pipe(
-            {"raw": audio, "sampling_rate": 16_000},
-            **pipe_kwargs,
-        )
+            kwargs["language"] = self.language
+        result = mlx_whisper.transcribe(audio, **kwargs)
         return result["text"].strip()

View File

@@ -11,9 +11,7 @@ dependencies = [
     "rumps>=0.4.0",
     "sounddevice>=0.4.6",
     "numpy>=1.24.0",
-    "torch>=2.0.0",
-    "transformers>=4.36.0",
-    "accelerate>=0.25.0",
+    "mlx-whisper>=0.4.0",
     "pynput>=1.7.6",
    "pyobjc-framework-Quartz>=9.0",
     "pyobjc-framework-Cocoa>=9.0",