feat: swap Huggingface Transformers back-end to mlx native whisper.
This commit is contained in:
@@ -6,9 +6,6 @@ import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Run offline — models are downloaded during setup, no need to hit HuggingFace on every launch.
|
||||
os.environ.setdefault("HF_HUB_OFFLINE", "1")
|
||||
|
||||
import subprocess
|
||||
|
||||
import rumps
|
||||
@@ -36,7 +33,7 @@ class CalliopeApp(rumps.App):
|
||||
self.overlay = WaveformOverlay()
|
||||
self.recorder = Recorder(device=cfg.get("device"))
|
||||
self.transcriber = Transcriber(
|
||||
model=cfg.get("model", "distil-whisper/distil-large-v3"),
|
||||
model=cfg.get("model", "mlx-community/whisper-large-v3-turbo"),
|
||||
silence_threshold=cfg.get("silence_threshold", 0.005),
|
||||
)
|
||||
self.transcriber.context = cfg.get("context", "")
|
||||
@@ -281,15 +278,11 @@ class CalliopeApp(rumps.App):
|
||||
log.info("Typing mode set to %s", mode)
|
||||
|
||||
def _release_transcriber(self) -> None:
|
||||
"""Free the current Whisper model to reclaim GPU memory."""
|
||||
"""Free the current Whisper model to reclaim memory."""
|
||||
import gc
|
||||
import torch
|
||||
if self.transcriber is not None:
|
||||
self.transcriber._pipe = None
|
||||
self.transcriber._tokenizer = None
|
||||
self.transcriber._loaded = False
|
||||
gc.collect()
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
def _on_toggle_click(self, sender) -> None:
|
||||
self._toggle_recording()
|
||||
|
||||
@@ -13,7 +13,7 @@ CONFIG_PATH = CONFIG_DIR / "config.yaml"
|
||||
|
||||
DEFAULTS: dict[str, Any] = {
|
||||
"device": None, # sounddevice index; None = system default
|
||||
"model": "distil-whisper/distil-large-v3",
|
||||
"model": "mlx-community/whisper-large-v3-turbo",
|
||||
"language": "auto",
|
||||
"hotkeys": {
|
||||
"ptt": "ctrl+shift",
|
||||
@@ -50,11 +50,11 @@ LANGUAGES: dict[str, str] = {
|
||||
}
|
||||
|
||||
MODELS: list[str] = [
|
||||
"distil-whisper/distil-large-v3",
|
||||
"openai/whisper-large-v3",
|
||||
"openai/whisper-base",
|
||||
"openai/whisper-small",
|
||||
"openai/whisper-medium",
|
||||
"mlx-community/whisper-large-v3-turbo",
|
||||
"mlx-community/whisper-large-v3",
|
||||
"mlx-community/whisper-small",
|
||||
"mlx-community/whisper-medium",
|
||||
"mlx-community/whisper-base",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
"""Whisper transcription using transformers pipeline on MPS."""
|
||||
"""Whisper transcription using mlx-whisper on Apple Silicon."""
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MODEL = "mlx-community/whisper-large-v3-turbo"
|
||||
|
||||
|
||||
class Transcriber:
|
||||
def __init__(self, model: str = "distil-whisper/distil-large-v3", silence_threshold: float = 0.005):
|
||||
def __init__(self, model: str = DEFAULT_MODEL, silence_threshold: float = 0.005):
|
||||
self.model = model
|
||||
self._pipe = None
|
||||
self._tokenizer = None
|
||||
self._loaded = False
|
||||
self._context: str = ""
|
||||
self._cached_prompt_ids = None
|
||||
self.language: str = "auto"
|
||||
self.silence_threshold = silence_threshold
|
||||
|
||||
@@ -26,35 +24,24 @@ class Transcriber:
|
||||
@context.setter
|
||||
def context(self, value: str) -> None:
|
||||
self._context = value
|
||||
self._cached_prompt_ids = None # invalidate cache
|
||||
|
||||
def load(self) -> None:
|
||||
from transformers import AutoTokenizer
|
||||
import mlx_whisper
|
||||
|
||||
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
# Use float32 on MPS — float16 produces garbled output on Apple Silicon.
|
||||
dtype = torch.float32 if device == "mps" else torch.float16
|
||||
log.info("Loading model %s on %s (dtype=%s)", self.model, device, dtype)
|
||||
log.info("Loading model %s via mlx-whisper", self.model)
|
||||
try:
|
||||
self._pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=self.model,
|
||||
torch_dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.model)
|
||||
log.info("Model loaded, running warmup...")
|
||||
self._pipe(
|
||||
{"raw": np.zeros(16_000, dtype=np.float32), "sampling_rate": 16_000},
|
||||
batch_size=1,
|
||||
mlx_whisper.transcribe(
|
||||
np.zeros(16_000, dtype=np.float32),
|
||||
path_or_hf_repo=self.model,
|
||||
)
|
||||
self._loaded = True
|
||||
log.info("Model ready")
|
||||
except Exception:
|
||||
log.error("Failed to load model %s", self.model, exc_info=True)
|
||||
raise
|
||||
|
||||
def transcribe(self, audio: np.ndarray) -> str:
|
||||
if self._pipe is None:
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
if audio.size == 0:
|
||||
return ""
|
||||
@@ -68,23 +55,15 @@ class Transcriber:
|
||||
log.debug("Audio too short or too quiet, skipping transcription")
|
||||
return ""
|
||||
|
||||
generate_kwargs = {}
|
||||
if self._context:
|
||||
if self._cached_prompt_ids is None:
|
||||
device = self._pipe.model.device
|
||||
self._cached_prompt_ids = torch.tensor(self._tokenizer.get_prompt_ids(self._context), device=device)
|
||||
generate_kwargs["prompt_ids"] = self._cached_prompt_ids
|
||||
import mlx_whisper
|
||||
|
||||
pipe_kwargs = {
|
||||
"batch_size": 1,
|
||||
"generate_kwargs": generate_kwargs,
|
||||
kwargs = {
|
||||
"path_or_hf_repo": self.model,
|
||||
}
|
||||
if self._context:
|
||||
kwargs["initial_prompt"] = self._context
|
||||
if self.language != "auto":
|
||||
pipe_kwargs["generate_kwargs"]["language"] = self.language
|
||||
pipe_kwargs["generate_kwargs"]["task"] = "transcribe"
|
||||
kwargs["language"] = self.language
|
||||
|
||||
result = self._pipe(
|
||||
{"raw": audio, "sampling_rate": 16_000},
|
||||
**pipe_kwargs,
|
||||
)
|
||||
result = mlx_whisper.transcribe(audio, **kwargs)
|
||||
return result["text"].strip()
|
||||
|
||||
@@ -11,9 +11,7 @@ dependencies = [
|
||||
"rumps>=0.4.0",
|
||||
"sounddevice>=0.4.6",
|
||||
"numpy>=1.24.0",
|
||||
"torch>=2.0.0",
|
||||
"transformers>=4.36.0",
|
||||
"accelerate>=0.25.0",
|
||||
"mlx-whisper>=0.4.0",
|
||||
"pynput>=1.7.6",
|
||||
"pyobjc-framework-Quartz>=9.0",
|
||||
"pyobjc-framework-Cocoa>=9.0",
|
||||
|
||||
Reference in New Issue
Block a user