"""Whisper transcription using mlx-whisper on Apple Silicon."""
|
|
|
|
import logging
|
|
|
|
import numpy as np
|
|
|
|
log = logging.getLogger(__name__)

# Default model weights: mlx-community conversion on the Hugging Face Hub.
DEFAULT_MODEL = "mlx-community/whisper-large-v3-turbo"

# Sample rate (Hz) Whisper models expect; default for incoming audio.
SAMPLE_RATE = 16_000


class Transcriber:
    """Lazy-loading speech-to-text wrapper around mlx-whisper.

    The model is loaded on first use (or explicitly via :meth:`load`).
    Audio is expected as a float32 mono waveform sampled at
    ``sample_rate`` Hz.
    """

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        silence_threshold: float = 0.005,
        sample_rate: int = SAMPLE_RATE,
        min_duration: float = 1.0,
    ) -> None:
        """Configure the transcriber; no model work happens here.

        Args:
            model: Local path or Hugging Face repo of the Whisper weights.
            silence_threshold: Minimum RMS energy; quieter clips are skipped.
            sample_rate: Sample rate (Hz) of incoming audio buffers.
            min_duration: Minimum clip length in seconds; shorter clips
                are skipped without transcription.
        """
        self.model = model
        self.silence_threshold = silence_threshold
        self.sample_rate = sample_rate
        self.min_duration = min_duration
        # "auto" lets Whisper detect the language instead of forcing one.
        self.language: str = "auto"
        self._loaded = False
        self._context: str = ""

    @property
    def context(self) -> str:
        """Optional initial prompt fed to Whisper to bias transcription."""
        return self._context

    @context.setter
    def context(self, value: str) -> None:
        self._context = value

    def load(self) -> None:
        """Warm up the model by transcribing one second of silence.

        mlx-whisper exposes no explicit load API; running a dummy
        transcription forces the weights to download and initialize.

        Raises:
            Exception: re-raised from mlx_whisper when loading fails
                (after logging the traceback).
        """
        # Local import: heavy, Apple-Silicon-only optional dependency.
        import mlx_whisper

        log.info("Loading model %s via mlx-whisper", self.model)
        try:
            mlx_whisper.transcribe(
                np.zeros(self.sample_rate, dtype=np.float32),
                path_or_hf_repo=self.model,
            )
        except Exception:
            log.error("Failed to load model %s", self.model, exc_info=True)
            raise
        self._loaded = True
        log.info("Model ready")

    def transcribe(self, audio: np.ndarray) -> str:
        """Transcribe a float32 mono waveform to text.

        Returns "" for empty, too-short (< ``min_duration``) or too-quiet
        clips — Whisper hallucinates punctuation like "!" on
        silence/noise. These cheap checks run *before* the lazy model
        load so a skipped clip never triggers the expensive load.

        Args:
            audio: 1-D float32 waveform at ``self.sample_rate`` Hz.

        Returns:
            The stripped transcript, or "" when the clip was skipped.
        """
        if audio.size == 0:
            return ""

        duration = audio.size / self.sample_rate
        energy = float(np.sqrt(np.mean(audio ** 2)))  # RMS amplitude
        log.debug("Audio: %.1fs, RMS energy: %.6f", duration, energy)
        if duration < self.min_duration or energy < self.silence_threshold:
            log.debug("Audio too short or too quiet, skipping transcription")
            return ""

        if not self._loaded:
            self.load()

        import mlx_whisper

        kwargs: dict[str, str] = {"path_or_hf_repo": self.model}
        if self._context:
            kwargs["initial_prompt"] = self._context
        if self.language != "auto":
            kwargs["language"] = self.language

        result = mlx_whisper.transcribe(audio, **kwargs)
        return result["text"].strip()
|