feat: swap the Hugging Face Transformers backend for native mlx-whisper.
This commit is contained in:
@@ -6,9 +6,6 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
# Run offline — models are downloaded during setup, no need to hit HuggingFace on every launch.
|
|
||||||
os.environ.setdefault("HF_HUB_OFFLINE", "1")
|
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
import rumps
|
import rumps
|
||||||
@@ -36,7 +33,7 @@ class CalliopeApp(rumps.App):
|
|||||||
self.overlay = WaveformOverlay()
|
self.overlay = WaveformOverlay()
|
||||||
self.recorder = Recorder(device=cfg.get("device"))
|
self.recorder = Recorder(device=cfg.get("device"))
|
||||||
self.transcriber = Transcriber(
|
self.transcriber = Transcriber(
|
||||||
model=cfg.get("model", "distil-whisper/distil-large-v3"),
|
model=cfg.get("model", "mlx-community/whisper-large-v3-turbo"),
|
||||||
silence_threshold=cfg.get("silence_threshold", 0.005),
|
silence_threshold=cfg.get("silence_threshold", 0.005),
|
||||||
)
|
)
|
||||||
self.transcriber.context = cfg.get("context", "")
|
self.transcriber.context = cfg.get("context", "")
|
||||||
@@ -281,15 +278,11 @@ class CalliopeApp(rumps.App):
|
|||||||
log.info("Typing mode set to %s", mode)
|
log.info("Typing mode set to %s", mode)
|
||||||
|
|
||||||
def _release_transcriber(self) -> None:
|
def _release_transcriber(self) -> None:
|
||||||
"""Free the current Whisper model to reclaim GPU memory."""
|
"""Free the current Whisper model to reclaim memory."""
|
||||||
import gc
|
import gc
|
||||||
import torch
|
|
||||||
if self.transcriber is not None:
|
if self.transcriber is not None:
|
||||||
self.transcriber._pipe = None
|
self.transcriber._loaded = False
|
||||||
self.transcriber._tokenizer = None
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
torch.mps.empty_cache()
|
|
||||||
|
|
||||||
def _on_toggle_click(self, sender) -> None:
|
def _on_toggle_click(self, sender) -> None:
|
||||||
self._toggle_recording()
|
self._toggle_recording()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ CONFIG_PATH = CONFIG_DIR / "config.yaml"
|
|||||||
|
|
||||||
DEFAULTS: dict[str, Any] = {
|
DEFAULTS: dict[str, Any] = {
|
||||||
"device": None, # sounddevice index; None = system default
|
"device": None, # sounddevice index; None = system default
|
||||||
"model": "distil-whisper/distil-large-v3",
|
"model": "mlx-community/whisper-large-v3-turbo",
|
||||||
"language": "auto",
|
"language": "auto",
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"ptt": "ctrl+shift",
|
"ptt": "ctrl+shift",
|
||||||
@@ -50,11 +50,11 @@ LANGUAGES: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MODELS: list[str] = [
|
MODELS: list[str] = [
|
||||||
"distil-whisper/distil-large-v3",
|
"mlx-community/whisper-large-v3-turbo",
|
||||||
"openai/whisper-large-v3",
|
"mlx-community/whisper-large-v3",
|
||||||
"openai/whisper-base",
|
"mlx-community/whisper-small",
|
||||||
"openai/whisper-small",
|
"mlx-community/whisper-medium",
|
||||||
"openai/whisper-medium",
|
"mlx-community/whisper-base",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,19 @@
|
|||||||
"""Whisper transcription using transformers pipeline on MPS."""
|
"""Whisper transcription using mlx-whisper on Apple Silicon."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
from transformers import pipeline
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "mlx-community/whisper-large-v3-turbo"
|
||||||
|
|
||||||
|
|
||||||
class Transcriber:
|
class Transcriber:
|
||||||
def __init__(self, model: str = "distil-whisper/distil-large-v3", silence_threshold: float = 0.005):
|
def __init__(self, model: str = DEFAULT_MODEL, silence_threshold: float = 0.005):
|
||||||
self.model = model
|
self.model = model
|
||||||
self._pipe = None
|
self._loaded = False
|
||||||
self._tokenizer = None
|
|
||||||
self._context: str = ""
|
self._context: str = ""
|
||||||
self._cached_prompt_ids = None
|
|
||||||
self.language: str = "auto"
|
self.language: str = "auto"
|
||||||
self.silence_threshold = silence_threshold
|
self.silence_threshold = silence_threshold
|
||||||
|
|
||||||
@@ -26,35 +24,24 @@ class Transcriber:
|
|||||||
@context.setter
|
@context.setter
|
||||||
def context(self, value: str) -> None:
|
def context(self, value: str) -> None:
|
||||||
self._context = value
|
self._context = value
|
||||||
self._cached_prompt_ids = None # invalidate cache
|
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
from transformers import AutoTokenizer
|
import mlx_whisper
|
||||||
|
|
||||||
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
log.info("Loading model %s via mlx-whisper", self.model)
|
||||||
# Use float32 on MPS — float16 produces garbled output on Apple Silicon.
|
|
||||||
dtype = torch.float32 if device == "mps" else torch.float16
|
|
||||||
log.info("Loading model %s on %s (dtype=%s)", self.model, device, dtype)
|
|
||||||
try:
|
try:
|
||||||
self._pipe = pipeline(
|
mlx_whisper.transcribe(
|
||||||
"automatic-speech-recognition",
|
np.zeros(16_000, dtype=np.float32),
|
||||||
model=self.model,
|
path_or_hf_repo=self.model,
|
||||||
torch_dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
self._tokenizer = AutoTokenizer.from_pretrained(self.model)
|
|
||||||
log.info("Model loaded, running warmup...")
|
|
||||||
self._pipe(
|
|
||||||
{"raw": np.zeros(16_000, dtype=np.float32), "sampling_rate": 16_000},
|
|
||||||
batch_size=1,
|
|
||||||
)
|
)
|
||||||
|
self._loaded = True
|
||||||
log.info("Model ready")
|
log.info("Model ready")
|
||||||
except Exception:
|
except Exception:
|
||||||
log.error("Failed to load model %s", self.model, exc_info=True)
|
log.error("Failed to load model %s", self.model, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def transcribe(self, audio: np.ndarray) -> str:
|
def transcribe(self, audio: np.ndarray) -> str:
|
||||||
if self._pipe is None:
|
if not self._loaded:
|
||||||
self.load()
|
self.load()
|
||||||
if audio.size == 0:
|
if audio.size == 0:
|
||||||
return ""
|
return ""
|
||||||
@@ -68,23 +55,15 @@ class Transcriber:
|
|||||||
log.debug("Audio too short or too quiet, skipping transcription")
|
log.debug("Audio too short or too quiet, skipping transcription")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
generate_kwargs = {}
|
import mlx_whisper
|
||||||
if self._context:
|
|
||||||
if self._cached_prompt_ids is None:
|
|
||||||
device = self._pipe.model.device
|
|
||||||
self._cached_prompt_ids = torch.tensor(self._tokenizer.get_prompt_ids(self._context), device=device)
|
|
||||||
generate_kwargs["prompt_ids"] = self._cached_prompt_ids
|
|
||||||
|
|
||||||
pipe_kwargs = {
|
kwargs = {
|
||||||
"batch_size": 1,
|
"path_or_hf_repo": self.model,
|
||||||
"generate_kwargs": generate_kwargs,
|
|
||||||
}
|
}
|
||||||
|
if self._context:
|
||||||
|
kwargs["initial_prompt"] = self._context
|
||||||
if self.language != "auto":
|
if self.language != "auto":
|
||||||
pipe_kwargs["generate_kwargs"]["language"] = self.language
|
kwargs["language"] = self.language
|
||||||
pipe_kwargs["generate_kwargs"]["task"] = "transcribe"
|
|
||||||
|
|
||||||
result = self._pipe(
|
result = mlx_whisper.transcribe(audio, **kwargs)
|
||||||
{"raw": audio, "sampling_rate": 16_000},
|
|
||||||
**pipe_kwargs,
|
|
||||||
)
|
|
||||||
return result["text"].strip()
|
return result["text"].strip()
|
||||||
|
|||||||
@@ -11,9 +11,7 @@ dependencies = [
|
|||||||
"rumps>=0.4.0",
|
"rumps>=0.4.0",
|
||||||
"sounddevice>=0.4.6",
|
"sounddevice>=0.4.6",
|
||||||
"numpy>=1.24.0",
|
"numpy>=1.24.0",
|
||||||
"torch>=2.0.0",
|
"mlx-whisper>=0.4.0",
|
||||||
"transformers>=4.36.0",
|
|
||||||
"accelerate>=0.25.0",
|
|
||||||
"pynput>=1.7.6",
|
"pynput>=1.7.6",
|
||||||
"pyobjc-framework-Quartz>=9.0",
|
"pyobjc-framework-Quartz>=9.0",
|
||||||
"pyobjc-framework-Cocoa>=9.0",
|
"pyobjc-framework-Cocoa>=9.0",
|
||||||
|
|||||||
Reference in New Issue
Block a user