feat: add LLM post-processing support.
This commit is contained in:
198
calliope/app.py
198
calliope/app.py
@@ -18,6 +18,7 @@ import rumps
|
||||
from calliope import config as config_mod
|
||||
from calliope.recorder import Recorder
|
||||
from calliope.transcriber import Transcriber
|
||||
from calliope.postprocessor import Postprocessor
|
||||
from calliope.typer import type_text, type_text_clipboard
|
||||
from calliope.hotkeys import HotkeyListener
|
||||
from calliope.overlay import WaveformOverlay
|
||||
@@ -43,6 +44,14 @@ class CalliopeApp(rumps.App):
|
||||
self.transcriber.context = cfg.get("context", "")
|
||||
self.transcriber.language = cfg.get("language", "auto")
|
||||
|
||||
# Post-processing
|
||||
pp_cfg = cfg.get("postprocessing", {})
|
||||
self.postprocessor: Postprocessor | None = None
|
||||
if pp_cfg.get("enabled") and pp_cfg.get("model"):
|
||||
self.postprocessor = Postprocessor(
|
||||
system_prompt=pp_cfg.get("system_prompt", ""),
|
||||
)
|
||||
|
||||
self._recording = False
|
||||
self._rec_lock = threading.Lock()
|
||||
self._rec_start_time: float | None = None
|
||||
@@ -78,6 +87,10 @@ class CalliopeApp(rumps.App):
|
||||
self._mic_menu = rumps.MenuItem("Microphone")
|
||||
self._build_mic_menu()
|
||||
|
||||
# Post-Processing submenu
|
||||
self._pp_menu = rumps.MenuItem("Post-Processing")
|
||||
self._build_pp_menu()
|
||||
|
||||
# Typing mode submenu
|
||||
self._typing_menu = rumps.MenuItem("Typing Mode")
|
||||
current_mode = cfg.get("typing_mode", "char")
|
||||
@@ -98,6 +111,7 @@ class CalliopeApp(rumps.App):
|
||||
self._model_menu,
|
||||
self._mic_menu,
|
||||
self._typing_menu,
|
||||
self._pp_menu,
|
||||
None,
|
||||
quit_item,
|
||||
]
|
||||
@@ -138,6 +152,10 @@ class CalliopeApp(rumps.App):
|
||||
self.status_item.title = self._ready_status()
|
||||
self.hotkeys.start()
|
||||
log.info("Model loaded, hotkeys active")
|
||||
# Load postprocessor if enabled
|
||||
pp_cfg = self.cfg.get("postprocessing", {})
|
||||
if pp_cfg.get("enabled") and pp_cfg.get("model"):
|
||||
self._ensure_postprocessor(pp_cfg["model"])
|
||||
except Exception:
|
||||
log.error("Failed to load model", exc_info=True)
|
||||
self.status_item.title = "Status: Model load failed"
|
||||
@@ -353,6 +371,14 @@ class CalliopeApp(rumps.App):
|
||||
self.status_item.title = self._ready_status()
|
||||
self._notify("Calliope", "", "No speech detected — audio too short or too quiet")
|
||||
return
|
||||
# LLM post-processing
|
||||
pp_cfg = self.cfg.get("postprocessing", {})
|
||||
if pp_cfg.get("enabled") and self.postprocessor and self.postprocessor._model is not None:
|
||||
try:
|
||||
self.status_item.title = "Status: Post-processing..."
|
||||
text = self.postprocessor.process(text)
|
||||
except Exception:
|
||||
log.error("Post-processing failed, using raw transcription", exc_info=True)
|
||||
if text:
|
||||
def _do_type():
|
||||
try:
|
||||
@@ -377,7 +403,179 @@ class CalliopeApp(rumps.App):
|
||||
self.title = "\U0001f3a4" # 🎤
|
||||
self._transcribe_done.set()
|
||||
|
||||
# ── Post-Processing ───────────────────────────────────────────
|
||||
|
||||
def _build_pp_menu(self) -> None:
    """Rebuild the Post-Processing submenu from the current config.

    Layout: enable/disable toggle, a separator, one entry per downloaded
    model (the active one marked with a check), then the download /
    prompt-edit / delete actions.
    """
    # rumps creates the underlying NSMenu lazily; clear() is only valid once it exists.
    if self._pp_menu._menu is not None:
        self._pp_menu.clear()

    settings = self.cfg.get("postprocessing", {})
    is_on = settings.get("enabled", False)
    selected = settings.get("model")
    downloaded = settings.get("models", [])

    # Toggle entry reflects the current enabled state.
    label = "Disable Post-Processing" if is_on else "Enable Post-Processing"
    self._pp_menu.add(rumps.MenuItem(label, callback=self._on_pp_toggle))
    self._pp_menu.add(None)  # separator

    # One entry per downloaded model; the active model gets a check mark.
    if downloaded:
        for repo_id in downloaded:
            name = repo_id.split("/")[-1]
            mark = "\u2713 " if repo_id == selected else "  "
            entry = rumps.MenuItem(f"{mark}{name}", callback=self._on_pp_model_select)
            entry._pp_model_id = repo_id  # stash the repo id for the select callback
            self._pp_menu.add(entry)
        self._pp_menu.add(None)

    self._pp_menu.add(rumps.MenuItem("Download Model...", callback=self._on_pp_download))
    self._pp_menu.add(rumps.MenuItem("Edit System Prompt...", callback=self._on_pp_edit_prompt))
    if downloaded:
        self._pp_menu.add(rumps.MenuItem("Delete Model...", callback=self._on_pp_delete))
|
||||
|
||||
def _on_pp_toggle(self, sender) -> None:
    """Flip the post-processing enabled flag, persist it, and (un)load the LLM."""
    settings = self.cfg.setdefault("postprocessing", {})
    now_enabled = not settings.get("enabled", False)
    settings["enabled"] = now_enabled
    config_mod.save(self.cfg)
    if now_enabled:
        # Only load if a model has actually been selected.
        if settings.get("model"):
            self._ensure_postprocessor(settings["model"])
    else:
        self._release_postprocessor()
    self._build_pp_menu()
    log.info("Post-processing %s", "enabled" if now_enabled else "disabled")
|
||||
|
||||
def _on_pp_model_select(self, sender) -> None:
    """Make the model behind the clicked menu item the active one."""
    chosen = sender._pp_model_id
    settings = self.cfg.setdefault("postprocessing", {})
    if chosen == settings.get("model"):
        # Already the active model; nothing to do.
        return
    settings["model"] = chosen
    config_mod.save(self.cfg)
    # Only (re)load the LLM when post-processing is actually on.
    if settings.get("enabled"):
        self._ensure_postprocessor(chosen)
    self._build_pp_menu()
    log.info("Post-processing model set to %s", chosen)
|
||||
|
||||
def _on_pp_download(self, sender) -> None:
    """Prompt for a HuggingFace repo ID and download the model in the background.

    On success the repo is recorded in the config's model list (and made the
    active model if none was set), the submenu is rebuilt, and a notification
    is posted. Failures are logged and reported via notification.
    """
    self._activate_app()
    response = rumps.Window(
        message="Enter a HuggingFace MLX model repo ID.\n\n"
        "Example: mlx-community/Qwen2.5-0.5B-Instruct-4bit",
        title="Download MLX Model",
        default_text="mlx-community/Qwen2.5-0.5B-Instruct-4bit",
        ok="Download",
        cancel="Cancel",
        dimensions=(320, 24),
    ).run()
    self._deactivate_app()
    if response.clicked != 1:  # anything other than the OK button
        return
    repo = response.text.strip()
    if not repo:
        return
    self._notify("Calliope", "", f"Downloading {repo}...")

    def _do_download():
        # Import BEFORE the try block: if this import were inside the try and
        # failed, the `finally` clause below would reference an unbound
        # `hf_constants` and raise NameError, masking the real error.
        import huggingface_hub.constants as hf_constants

        try:
            # Temporarily lift HF offline mode so the snapshot can be fetched.
            os.environ["HF_HUB_OFFLINE"] = "0"
            hf_constants.HF_HUB_OFFLINE = False
            Postprocessor.download(repo)
            pp_cfg = self.cfg.setdefault("postprocessing", {})
            if repo not in pp_cfg.setdefault("models", []):
                pp_cfg["models"].append(repo)
            if not pp_cfg.get("model"):
                pp_cfg["model"] = repo
            config_mod.save(self.cfg)
            # NOTE(review): this runs on a worker thread; rumps/AppKit menu
            # mutation is normally main-thread-only — confirm this is safe.
            self._build_pp_menu()
            self._notify("Calliope", "", f"Model downloaded: {repo}")
        except Exception:
            log.error("Failed to download %s", repo, exc_info=True)
            self._notify("Calliope", "Error", f"Failed to download {repo}")
        finally:
            # Restore offline mode; the app presumably runs offline by
            # default (it forces "1" rather than restoring the prior value).
            os.environ["HF_HUB_OFFLINE"] = "1"
            hf_constants.HF_HUB_OFFLINE = True

    threading.Thread(target=_do_download, daemon=True).start()
|
||||
|
||||
def _on_pp_edit_prompt(self, sender) -> None:
    """Open a dialog to edit the LLM system prompt and persist the result."""
    settings = self.cfg.setdefault("postprocessing", {})
    existing = settings.get("system_prompt", "")
    self._activate_app()
    result = rumps.Window(
        message="System prompt sent to the LLM before your transcription:",
        title="Edit System Prompt",
        default_text=existing,
        ok="Save",
        cancel="Cancel",
        dimensions=(320, 120),
    ).run()
    self._deactivate_app()
    if result.clicked != 1:
        return
    settings["system_prompt"] = result.text.strip()
    config_mod.save(self.cfg)
    # Push the new prompt into an already-loaded postprocessor right away;
    # an empty prompt falls back to the package default.
    if self.postprocessor:
        from calliope.postprocessor import DEFAULT_SYSTEM_PROMPT

        self.postprocessor.system_prompt = settings["system_prompt"] or DEFAULT_SYSTEM_PROMPT
    log.info("Post-processing system prompt updated")
|
||||
|
||||
def _on_pp_delete(self, sender) -> None:
    """Ask which downloaded model to forget and drop it from the config.

    Only removes the entry from Calliope's config (the HuggingFace disk
    cache is untouched). Disables post-processing when no models remain.
    """
    settings = self.cfg.setdefault("postprocessing", {})
    known = settings.get("models", [])
    if not known:
        return
    self._activate_app()
    result = rumps.Window(
        message="Enter the repo ID of the model to remove from Calliope:\n\n"
        + "\n".join(f"  • {m}" for m in known),
        title="Delete Model",
        default_text="",
        ok="Delete",
        cancel="Cancel",
        dimensions=(320, 24),
    ).run()
    self._deactivate_app()
    if result.clicked != 1:
        return
    target = result.text.strip()
    if target not in known:
        return
    known.remove(target)
    if settings.get("model") == target:
        # Fall back to the first remaining model, or none at all.
        settings["model"] = known[0] if known else None
    if not known:
        settings["enabled"] = False
        self._release_postprocessor()
    config_mod.save(self.cfg)
    self._build_pp_menu()
    log.info("Removed model %s", target)
|
||||
|
||||
def _ensure_postprocessor(self, model_id: str) -> None:
    """Load the postprocessor model in a background thread.

    Creates the Postprocessor wrapper on first use, then swaps in the
    requested model. Failures are logged and surfaced as a notification.
    """

    def _worker():
        try:
            if self.postprocessor is None:
                # First use: build the wrapper with the configured prompt.
                settings = self.cfg.get("postprocessing", {})
                self.postprocessor = Postprocessor(
                    system_prompt=settings.get("system_prompt", ""),
                )
            # Drop any previously loaded weights before loading the new model.
            self.postprocessor.unload()
            self.postprocessor.load(model_id)
        except Exception:
            log.error("Failed to load postprocessor %s", model_id, exc_info=True)
            self._notify("Calliope", "Error", f"Failed to load LLM: {model_id}")

    threading.Thread(target=_worker, daemon=True).start()
|
||||
|
||||
def _release_postprocessor(self) -> None:
    """Unload the LLM (if any) and drop the wrapper so its memory can be reclaimed."""
    pp = self.postprocessor
    if pp is None:
        return
    pp.unload()
    self.postprocessor = None
|
||||
|
||||
def _on_quit(self, sender) -> None:
|
||||
self._release_postprocessor()
|
||||
self.hotkeys.stop()
|
||||
self.recorder.stop()
|
||||
# Stop overlay timers synchronously to avoid retain cycles on quit.
|
||||
|
||||
@@ -26,6 +26,12 @@ DEFAULTS: dict[str, Any] = {
|
||||
"silence_threshold": 0.005, # RMS energy below which audio is considered silence
|
||||
"notifications": True, # show macOS notifications
|
||||
"typing_delay": 0.005, # seconds between keystrokes in char mode
|
||||
"postprocessing": {
|
||||
"enabled": False,
|
||||
"model": None, # active model HF repo id
|
||||
"models": [], # list of downloaded model repo ids
|
||||
"system_prompt": "Fix grammar and punctuation in the following dictated text. Output only the corrected text, nothing else.",
|
||||
},
|
||||
}
|
||||
|
||||
LANGUAGES: dict[str, str] = {
|
||||
|
||||
72
calliope/postprocessor.py
Normal file
72
calliope/postprocessor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""LLM post-processing of transcriptions using MLX on Apple Silicon."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
log = logging.getLogger(__name__)

DEFAULT_SYSTEM_PROMPT = (
    "You are a speech-to-text post-processor. Your sole job is to clean up "
    "raw transcriptions. Fix punctuation, capitalization, and obvious "
    "mistranscriptions. Do not add, remove, or rephrase any words beyond "
    "what is necessary for correctness. Output ONLY the corrected text with "
    "no commentary, explanations, or prefixes."
)


class Postprocessor:
    """Cleans up raw transcriptions with a local MLX language model.

    Weights are loaded lazily via :meth:`load` and released via
    :meth:`unload`; ``mlx_lm`` is imported inside the methods so the app
    can start (and this class can be constructed) without it installed.
    """

    def __init__(self, system_prompt: str = ""):
        # An empty prompt falls back to the package default.
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self._model = None
        self._tokenizer = None
        self._model_id: str | None = None

    def load(self, model_id: str) -> None:
        """Load the model weights and tokenizer for *model_id* (a HF repo id)."""
        from mlx_lm import load

        log.info("Loading MLX model %s", model_id)
        self._model, self._tokenizer = load(model_id)
        self._model_id = model_id
        log.info("MLX model ready")

    def process(self, text: str) -> str:
        """Return *text* cleaned up by the loaded LLM.

        Raises:
            RuntimeError: if :meth:`load` has not been called successfully.
        """
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Postprocessor model not loaded")

        from mlx_lm import generate

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": text},
        ]
        prompt = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Rough character-based budget for the cleaned text itself.
        output_budget = len(text) * 2 + 100
        # Allow extra headroom for reasoning/thinking tokens (stripped later).
        reasoning_budget = 2048
        result = generate(
            self._model, self._tokenizer, prompt=prompt, max_tokens=output_budget + reasoning_budget
        )
        # Strip <think>...</think> reasoning blocks. Also drop an UNCLOSED
        # <think> tail: if generation hit the token budget mid-thought the
        # closing tag never appears, and without this the raw reasoning text
        # would be returned (and typed out) verbatim.
        result = re.sub(r"<think>[\s\S]*?(?:</think>|$)", "", result)
        result = result.strip()
        log.debug("Post-processing input: %s", text)
        log.debug("Post-processing output: %s", result)
        return result

    def unload(self) -> None:
        """Release model/tokenizer references and force a GC pass to free memory."""
        import gc

        self._model = None
        self._tokenizer = None
        self._model_id = None
        gc.collect()
        log.info("MLX model unloaded")

    @staticmethod
    def download(hf_repo: str) -> None:
        """Fetch *hf_repo* into the local HuggingFace cache (no load)."""
        log.info("Downloading MLX model %s", hf_repo)
        snapshot_download(hf_repo)
        log.info("Download complete: %s", hf_repo)
|
||||
@@ -99,6 +99,33 @@ def run() -> dict:
|
||||
text = transcriber.transcribe(audio)
|
||||
console.print(f"[green]Result:[/green] {text or '(no speech detected)'}")
|
||||
|
||||
# ── LLM Post-Processing ─────────────────────────────────────────
|
||||
console.print("\n[bold]LLM Post-Processing (optional)[/bold]")
|
||||
console.print(" Clean up grammar & punctuation using a local MLX language model.")
|
||||
if Confirm.ask("Enable LLM post-processing?", default=False):
|
||||
default_llm = "mlx-community/Qwen2.5-0.5B-Instruct-4bit"
|
||||
llm_repo = Prompt.ask("MLX model repo", default=default_llm)
|
||||
console.print(f"Downloading [cyan]{llm_repo}[/cyan]...")
|
||||
|
||||
from calliope.postprocessor import Postprocessor
|
||||
|
||||
with Progress() as progress:
|
||||
task = progress.add_task("Downloading model...", total=None)
|
||||
Postprocessor.download(llm_repo)
|
||||
progress.update(task, completed=100, total=100)
|
||||
|
||||
console.print("[green]Model downloaded.[/green]")
|
||||
|
||||
pp_cfg = cfg.setdefault("postprocessing", {})
|
||||
pp_cfg["enabled"] = True
|
||||
pp_cfg["model"] = llm_repo
|
||||
pp_cfg["models"] = [llm_repo]
|
||||
|
||||
default_prompt = config.DEFAULTS["postprocessing"]["system_prompt"]
|
||||
current_prompt = pp_cfg.get("system_prompt", default_prompt)
|
||||
if not Confirm.ask(f"Use default system prompt?\n \"{current_prompt}\"", default=True):
|
||||
pp_cfg["system_prompt"] = Prompt.ask("System prompt")
|
||||
|
||||
# ── Save ─────────────────────────────────────────────────────────
|
||||
config.save(cfg)
|
||||
console.print(f"\n[green]Config saved to {config.CONFIG_PATH}[/green]")
|
||||
|
||||
@@ -21,6 +21,9 @@ dependencies = [
|
||||
"rich>=13.0.0",
|
||||
"click>=8.1.0",
|
||||
"pyyaml>=6.0",
|
||||
"mlx>=0.16.0",
|
||||
"mlx-lm>=0.14.0",
|
||||
"huggingface-hub>=0.20.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
Reference in New Issue
Block a user