diff --git a/calliope/app.py b/calliope/app.py
index dafe905..39ee036 100644
--- a/calliope/app.py
+++ b/calliope/app.py
@@ -6,9 +6,6 @@
 import threading
 import time
 from typing import Any
 
-# Run offline — models are downloaded during setup, no need to hit HuggingFace on every launch.
-os.environ.setdefault("HF_HUB_OFFLINE", "1")
-
 import subprocess
 import rumps
@@ -36,7 +33,7 @@ class CalliopeApp(rumps.App):
         self.overlay = WaveformOverlay()
         self.recorder = Recorder(device=cfg.get("device"))
         self.transcriber = Transcriber(
-            model=cfg.get("model", "distil-whisper/distil-large-v3"),
+            model=cfg.get("model", "mlx-community/whisper-large-v3-turbo"),
             silence_threshold=cfg.get("silence_threshold", 0.005),
         )
         self.transcriber.context = cfg.get("context", "")
@@ -281,15 +278,11 @@ class CalliopeApp(rumps.App):
         log.info("Typing mode set to %s", mode)
 
     def _release_transcriber(self) -> None:
-        """Free the current Whisper model to reclaim GPU memory."""
+        """Free the current Whisper model to reclaim memory."""
         import gc
-        import torch
         if self.transcriber is not None:
-            self.transcriber._pipe = None
-            self.transcriber._tokenizer = None
+            self.transcriber._loaded = False
             gc.collect()
-            if torch.backends.mps.is_available():
-                torch.mps.empty_cache()
 
     def _on_toggle_click(self, sender) -> None:
         self._toggle_recording()
diff --git a/calliope/config.py b/calliope/config.py
index 6ffa6f3..b5ca10c 100644
--- a/calliope/config.py
+++ b/calliope/config.py
@@ -13,7 +13,7 @@ CONFIG_PATH = CONFIG_DIR / "config.yaml"
 
 DEFAULTS: dict[str, Any] = {
     "device": None,  # sounddevice index; None = system default
-    "model": "distil-whisper/distil-large-v3",
+    "model": "mlx-community/whisper-large-v3-turbo",
     "language": "auto",
     "hotkeys": {
         "ptt": "ctrl+shift",
@@ -50,11 +50,11 @@ LANGUAGES: dict[str, str] = {
 }
 
 MODELS: list[str] = [
-    "distil-whisper/distil-large-v3",
-    "openai/whisper-large-v3",
-    "openai/whisper-base",
-    "openai/whisper-small",
-    "openai/whisper-medium",
+    "mlx-community/whisper-large-v3-turbo",
+    "mlx-community/whisper-large-v3",
+    "mlx-community/whisper-small",
+    "mlx-community/whisper-medium",
+    "mlx-community/whisper-base",
 ]
diff --git a/calliope/transcriber.py b/calliope/transcriber.py
index 7858819..c403a83 100644
--- a/calliope/transcriber.py
+++ b/calliope/transcriber.py
@@ -1,21 +1,19 @@
-"""Whisper transcription using transformers pipeline on MPS."""
+"""Whisper transcription using mlx-whisper on Apple Silicon."""
 
 import logging
 
 import numpy as np
-import torch
-from transformers import pipeline
 
 log = logging.getLogger(__name__)
 
+DEFAULT_MODEL = "mlx-community/whisper-large-v3-turbo"
+
 
 class Transcriber:
-    def __init__(self, model: str = "distil-whisper/distil-large-v3", silence_threshold: float = 0.005):
+    def __init__(self, model: str = DEFAULT_MODEL, silence_threshold: float = 0.005):
         self.model = model
-        self._pipe = None
-        self._tokenizer = None
+        self._loaded = False
         self._context: str = ""
-        self._cached_prompt_ids = None
         self.language: str = "auto"
         self.silence_threshold = silence_threshold
 
@@ -26,35 +24,24 @@
     @context.setter
     def context(self, value: str) -> None:
         self._context = value
-        self._cached_prompt_ids = None  # invalidate cache
 
     def load(self) -> None:
-        from transformers import AutoTokenizer
+        import mlx_whisper
 
-        device = "mps" if torch.backends.mps.is_available() else "cpu"
-        # Use float32 on MPS — float16 produces garbled output on Apple Silicon.
-        dtype = torch.float32 if device == "mps" else torch.float16
-        log.info("Loading model %s on %s (dtype=%s)", self.model, device, dtype)
+        log.info("Loading model %s via mlx-whisper", self.model)
         try:
-            self._pipe = pipeline(
-                "automatic-speech-recognition",
-                model=self.model,
-                torch_dtype=dtype,
-                device=device,
-            )
-            self._tokenizer = AutoTokenizer.from_pretrained(self.model)
-            log.info("Model loaded, running warmup...")
-            self._pipe(
-                {"raw": np.zeros(16_000, dtype=np.float32), "sampling_rate": 16_000},
-                batch_size=1,
+            mlx_whisper.transcribe(
+                np.zeros(16_000, dtype=np.float32),
+                path_or_hf_repo=self.model,
             )
+            self._loaded = True
             log.info("Model ready")
         except Exception:
             log.error("Failed to load model %s", self.model, exc_info=True)
             raise
 
     def transcribe(self, audio: np.ndarray) -> str:
-        if self._pipe is None:
+        if not self._loaded:
             self.load()
         if audio.size == 0:
             return ""
@@ -68,23 +55,15 @@
             log.debug("Audio too short or too quiet, skipping transcription")
             return ""
 
-        generate_kwargs = {}
-        if self._context:
-            if self._cached_prompt_ids is None:
-                device = self._pipe.model.device
-                self._cached_prompt_ids = torch.tensor(self._tokenizer.get_prompt_ids(self._context), device=device)
-            generate_kwargs["prompt_ids"] = self._cached_prompt_ids
+        import mlx_whisper
 
-        pipe_kwargs = {
-            "batch_size": 1,
-            "generate_kwargs": generate_kwargs,
+        kwargs = {
+            "path_or_hf_repo": self.model,
         }
+        if self._context:
+            kwargs["initial_prompt"] = self._context
         if self.language != "auto":
-            pipe_kwargs["generate_kwargs"]["language"] = self.language
-            pipe_kwargs["generate_kwargs"]["task"] = "transcribe"
+            kwargs["language"] = self.language
 
-        result = self._pipe(
-            {"raw": audio, "sampling_rate": 16_000},
-            **pipe_kwargs,
-        )
+        result = mlx_whisper.transcribe(audio, **kwargs)
         return result["text"].strip()
diff --git a/pyproject.toml b/pyproject.toml
index 36b1afe..84aa3a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,9 +11,7 @@ dependencies = [
     "rumps>=0.4.0",
     "sounddevice>=0.4.6",
     "numpy>=1.24.0",
-    "torch>=2.0.0",
-    "transformers>=4.36.0",
-    "accelerate>=0.25.0",
+    "mlx-whisper>=0.4.0",
     "pynput>=1.7.6",
     "pyobjc-framework-Quartz>=9.0",
     "pyobjc-framework-Cocoa>=9.0",
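
Note: for anyone reviewing the new call path outside the app, the mlx-whisper usage above reduces to the standalone snippet below. This is a minimal sketch, not part of the change: the model repo matches the new default in config.py, while the initial_prompt and language values are made-up illustrations of what Transcriber passes through from config.

    import mlx_whisper
    import numpy as np

    # Stand-in for a real recording: one second of 16 kHz mono float32 silence,
    # the same shape of input Transcriber.transcribe() receives from Recorder.
    audio = np.zeros(16_000, dtype=np.float32)

    result = mlx_whisper.transcribe(
        audio,
        path_or_hf_repo="mlx-community/whisper-large-v3-turbo",
        initial_prompt="Calliope, rumps, sounddevice",  # illustrative context string
        language="en",  # omit to auto-detect, as the "auto" setting does
    )
    print(result["text"].strip())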