feat: add LLM post-processing support.
This commit is contained in:
198
calliope/app.py
198
calliope/app.py
@@ -18,6 +18,7 @@ import rumps
|
|||||||
from calliope import config as config_mod
|
from calliope import config as config_mod
|
||||||
from calliope.recorder import Recorder
|
from calliope.recorder import Recorder
|
||||||
from calliope.transcriber import Transcriber
|
from calliope.transcriber import Transcriber
|
||||||
|
from calliope.postprocessor import Postprocessor
|
||||||
from calliope.typer import type_text, type_text_clipboard
|
from calliope.typer import type_text, type_text_clipboard
|
||||||
from calliope.hotkeys import HotkeyListener
|
from calliope.hotkeys import HotkeyListener
|
||||||
from calliope.overlay import WaveformOverlay
|
from calliope.overlay import WaveformOverlay
|
||||||
@@ -43,6 +44,14 @@ class CalliopeApp(rumps.App):
|
|||||||
self.transcriber.context = cfg.get("context", "")
|
self.transcriber.context = cfg.get("context", "")
|
||||||
self.transcriber.language = cfg.get("language", "auto")
|
self.transcriber.language = cfg.get("language", "auto")
|
||||||
|
|
||||||
|
# Post-processing
|
||||||
|
pp_cfg = cfg.get("postprocessing", {})
|
||||||
|
self.postprocessor: Postprocessor | None = None
|
||||||
|
if pp_cfg.get("enabled") and pp_cfg.get("model"):
|
||||||
|
self.postprocessor = Postprocessor(
|
||||||
|
system_prompt=pp_cfg.get("system_prompt", ""),
|
||||||
|
)
|
||||||
|
|
||||||
self._recording = False
|
self._recording = False
|
||||||
self._rec_lock = threading.Lock()
|
self._rec_lock = threading.Lock()
|
||||||
self._rec_start_time: float | None = None
|
self._rec_start_time: float | None = None
|
||||||
@@ -78,6 +87,10 @@ class CalliopeApp(rumps.App):
|
|||||||
self._mic_menu = rumps.MenuItem("Microphone")
|
self._mic_menu = rumps.MenuItem("Microphone")
|
||||||
self._build_mic_menu()
|
self._build_mic_menu()
|
||||||
|
|
||||||
|
# Post-Processing submenu
|
||||||
|
self._pp_menu = rumps.MenuItem("Post-Processing")
|
||||||
|
self._build_pp_menu()
|
||||||
|
|
||||||
# Typing mode submenu
|
# Typing mode submenu
|
||||||
self._typing_menu = rumps.MenuItem("Typing Mode")
|
self._typing_menu = rumps.MenuItem("Typing Mode")
|
||||||
current_mode = cfg.get("typing_mode", "char")
|
current_mode = cfg.get("typing_mode", "char")
|
||||||
@@ -98,6 +111,7 @@ class CalliopeApp(rumps.App):
|
|||||||
self._model_menu,
|
self._model_menu,
|
||||||
self._mic_menu,
|
self._mic_menu,
|
||||||
self._typing_menu,
|
self._typing_menu,
|
||||||
|
self._pp_menu,
|
||||||
None,
|
None,
|
||||||
quit_item,
|
quit_item,
|
||||||
]
|
]
|
||||||
@@ -138,6 +152,10 @@ class CalliopeApp(rumps.App):
|
|||||||
self.status_item.title = self._ready_status()
|
self.status_item.title = self._ready_status()
|
||||||
self.hotkeys.start()
|
self.hotkeys.start()
|
||||||
log.info("Model loaded, hotkeys active")
|
log.info("Model loaded, hotkeys active")
|
||||||
|
# Load postprocessor if enabled
|
||||||
|
pp_cfg = self.cfg.get("postprocessing", {})
|
||||||
|
if pp_cfg.get("enabled") and pp_cfg.get("model"):
|
||||||
|
self._ensure_postprocessor(pp_cfg["model"])
|
||||||
except Exception:
|
except Exception:
|
||||||
log.error("Failed to load model", exc_info=True)
|
log.error("Failed to load model", exc_info=True)
|
||||||
self.status_item.title = "Status: Model load failed"
|
self.status_item.title = "Status: Model load failed"
|
||||||
@@ -353,6 +371,14 @@ class CalliopeApp(rumps.App):
|
|||||||
self.status_item.title = self._ready_status()
|
self.status_item.title = self._ready_status()
|
||||||
self._notify("Calliope", "", "No speech detected — audio too short or too quiet")
|
self._notify("Calliope", "", "No speech detected — audio too short or too quiet")
|
||||||
return
|
return
|
||||||
|
# LLM post-processing
|
||||||
|
pp_cfg = self.cfg.get("postprocessing", {})
|
||||||
|
if pp_cfg.get("enabled") and self.postprocessor and self.postprocessor._model is not None:
|
||||||
|
try:
|
||||||
|
self.status_item.title = "Status: Post-processing..."
|
||||||
|
text = self.postprocessor.process(text)
|
||||||
|
except Exception:
|
||||||
|
log.error("Post-processing failed, using raw transcription", exc_info=True)
|
||||||
if text:
|
if text:
|
||||||
def _do_type():
|
def _do_type():
|
||||||
try:
|
try:
|
||||||
@@ -377,7 +403,179 @@ class CalliopeApp(rumps.App):
|
|||||||
self.title = "\U0001f3a4" # 🎤
|
self.title = "\U0001f3a4" # 🎤
|
||||||
self._transcribe_done.set()
|
self._transcribe_done.set()
|
||||||
|
|
||||||
|
# ── Post-Processing ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def _build_pp_menu(self) -> None:
    """Rebuild the Post-Processing submenu from the current config."""
    # rumps lazily creates the native menu; only clear once it exists.
    if self._pp_menu._menu is not None:
        self._pp_menu.clear()
    pp_cfg = self.cfg.get("postprocessing", {})
    is_enabled = pp_cfg.get("enabled", False)
    selected = pp_cfg.get("model")
    downloaded = pp_cfg.get("models", [])

    # Enable/disable toggle
    toggle_title = "Disable Post-Processing" if is_enabled else "Enable Post-Processing"
    self._pp_menu.add(rumps.MenuItem(toggle_title, callback=self._on_pp_toggle))
    self._pp_menu.add(None)  # separator

    # One entry per downloaded model; the active one gets a checkmark.
    if downloaded:
        for repo_id in downloaded:
            display = repo_id.split("/")[-1]
            marker = "\u2713 " if repo_id == selected else "   "
            entry = rumps.MenuItem(f"{marker}{display}", callback=self._on_pp_model_select)
            entry._pp_model_id = repo_id  # stash repo id for the callback
            self._pp_menu.add(entry)
        self._pp_menu.add(None)

    self._pp_menu.add(rumps.MenuItem("Download Model...", callback=self._on_pp_download))
    self._pp_menu.add(rumps.MenuItem("Edit System Prompt...", callback=self._on_pp_edit_prompt))
    if downloaded:
        self._pp_menu.add(rumps.MenuItem("Delete Model...", callback=self._on_pp_delete))
|
def _on_pp_toggle(self, sender) -> None:
    """Flip the post-processing enabled flag, persist it, and (un)load the LLM."""
    pp_cfg = self.cfg.setdefault("postprocessing", {})
    now_enabled = not pp_cfg.get("enabled", False)
    pp_cfg["enabled"] = now_enabled
    config_mod.save(self.cfg)
    if now_enabled:
        # Only load when a model has actually been chosen.
        if pp_cfg.get("model"):
            self._ensure_postprocessor(pp_cfg["model"])
    else:
        self._release_postprocessor()
    self._build_pp_menu()
    log.info("Post-processing %s", "enabled" if now_enabled else "disabled")
|
def _on_pp_model_select(self, sender) -> None:
    """Make the clicked menu entry's model the active post-processing model."""
    chosen = sender._pp_model_id
    pp_cfg = self.cfg.setdefault("postprocessing", {})
    if chosen == pp_cfg.get("model"):
        # Already the active model — nothing to do.
        return
    pp_cfg["model"] = chosen
    config_mod.save(self.cfg)
    # Only swap the loaded LLM when the feature is actually enabled.
    if pp_cfg.get("enabled"):
        self._ensure_postprocessor(chosen)
    self._build_pp_menu()
    log.info("Post-processing model set to %s", chosen)
|
def _on_pp_download(self, sender) -> None:
    """Prompt for a HuggingFace MLX repo ID and download it in the background.

    Temporarily lifts HF offline mode for the duration of the download and
    restores it afterwards. On success the repo is recorded in the config's
    model list and becomes the active model if none was set.
    """
    self._activate_app()
    response = rumps.Window(
        message="Enter a HuggingFace MLX model repo ID.\n\n"
        "Example: mlx-community/Qwen2.5-0.5B-Instruct-4bit",
        title="Download MLX Model",
        default_text="mlx-community/Qwen2.5-0.5B-Instruct-4bit",
        ok="Download",
        cancel="Cancel",
        dimensions=(320, 24),
    ).run()
    self._deactivate_app()
    if response.clicked != 1:
        return
    repo = response.text.strip()
    if not repo:
        return
    self._notify("Calliope", "", f"Downloading {repo}...")

    def _do_download():
        # BUGFIX: the import used to sit inside the try/finally below, so an
        # import failure made the `finally` raise NameError on `hf_constants`,
        # masking the real error. Import (and handle failure) first.
        try:
            import huggingface_hub.constants as hf_constants
        except Exception:
            log.error("Failed to download %s", repo, exc_info=True)
            self._notify("Calliope", "Error", f"Failed to download {repo}")
            return
        # Lift offline mode just for this download.
        os.environ["HF_HUB_OFFLINE"] = "0"
        hf_constants.HF_HUB_OFFLINE = False
        try:
            Postprocessor.download(repo)
            pp_cfg = self.cfg.setdefault("postprocessing", {})
            if repo not in pp_cfg.setdefault("models", []):
                pp_cfg["models"].append(repo)
            if not pp_cfg.get("model"):
                pp_cfg["model"] = repo
            config_mod.save(self.cfg)
            self._build_pp_menu()
            self._notify("Calliope", "", f"Model downloaded: {repo}")
        except Exception:
            log.error("Failed to download %s", repo, exc_info=True)
            self._notify("Calliope", "Error", f"Failed to download {repo}")
        finally:
            # Restore offline mode so normal operation never hits the network.
            os.environ["HF_HUB_OFFLINE"] = "1"
            hf_constants.HF_HUB_OFFLINE = True

    threading.Thread(target=_do_download, daemon=True).start()
||||||
|
def _on_pp_edit_prompt(self, sender) -> None:
    """Show a dialog to edit the LLM system prompt and persist the result."""
    pp_cfg = self.cfg.setdefault("postprocessing", {})
    existing = pp_cfg.get("system_prompt", "")
    self._activate_app()
    win = rumps.Window(
        message="System prompt sent to the LLM before your transcription:",
        title="Edit System Prompt",
        default_text=existing,
        ok="Save",
        cancel="Cancel",
        dimensions=(320, 120),
    )
    response = win.run()
    self._deactivate_app()
    if response.clicked != 1:
        return
    pp_cfg["system_prompt"] = response.text.strip()
    config_mod.save(self.cfg)
    # Push the new prompt into a live postprocessor immediately, falling back
    # to the packaged default when the user saved an empty prompt.
    if self.postprocessor:
        from calliope.postprocessor import DEFAULT_SYSTEM_PROMPT
        self.postprocessor.system_prompt = pp_cfg["system_prompt"] or DEFAULT_SYSTEM_PROMPT
    log.info("Post-processing system prompt updated")
||||||
|
def _on_pp_delete(self, sender) -> None:
    """Ask which downloaded model to forget and remove it from the config."""
    pp_cfg = self.cfg.setdefault("postprocessing", {})
    models = pp_cfg.get("models", [])
    if not models:
        return
    listing = "\n".join(f"  • {m}" for m in models)
    self._activate_app()
    response = rumps.Window(
        message="Enter the repo ID of the model to remove from Calliope:\n\n" + listing,
        title="Delete Model",
        default_text="",
        ok="Delete",
        cancel="Cancel",
        dimensions=(320, 24),
    ).run()
    self._deactivate_app()
    if response.clicked != 1:
        return
    repo = response.text.strip()
    if repo not in models:
        return
    models.remove(repo)
    # Keep the active-model pointer valid after removal.
    if pp_cfg.get("model") == repo:
        pp_cfg["model"] = models[0] if models else None
    if not models:
        # Nothing left to run: disable the feature and free the loaded LLM.
        pp_cfg["enabled"] = False
        self._release_postprocessor()
    config_mod.save(self.cfg)
    self._build_pp_menu()
    log.info("Removed model %s", repo)
||||||
|
def _ensure_postprocessor(self, model_id: str) -> None:
    """Load ``model_id`` into the postprocessor on a background thread.

    Creates the Postprocessor lazily on first use. If the requested model is
    already loaded this is a no-op, avoiding a needless unload/reload cycle.
    Failures are logged and surfaced via a notification; they never crash
    the app.
    """
    def _load():
        try:
            if self.postprocessor is None:
                pp_cfg = self.cfg.get("postprocessing", {})
                self.postprocessor = Postprocessor(
                    system_prompt=pp_cfg.get("system_prompt", ""),
                )
            elif (
                self.postprocessor._model is not None
                and self.postprocessor._model_id == model_id
            ):
                # Requested model is already active — skip the reload.
                return
            self.postprocessor.unload()
            self.postprocessor.load(model_id)
        except Exception:
            log.error("Failed to load postprocessor %s", model_id, exc_info=True)
            self._notify("Calliope", "Error", f"Failed to load LLM: {model_id}")

    threading.Thread(target=_load, daemon=True).start()
||||||
|
def _release_postprocessor(self) -> None:
    """Unload and drop the postprocessor, freeing the MLX model's memory."""
    pp = self.postprocessor
    if pp is None:
        return
    pp.unload()
    self.postprocessor = None
||||||
def _on_quit(self, sender) -> None:
|
def _on_quit(self, sender) -> None:
|
||||||
|
self._release_postprocessor()
|
||||||
self.hotkeys.stop()
|
self.hotkeys.stop()
|
||||||
self.recorder.stop()
|
self.recorder.stop()
|
||||||
# Stop overlay timers synchronously to avoid retain cycles on quit.
|
# Stop overlay timers synchronously to avoid retain cycles on quit.
|
||||||
|
|||||||
@@ -26,6 +26,12 @@ DEFAULTS: dict[str, Any] = {
|
|||||||
"silence_threshold": 0.005, # RMS energy below which audio is considered silence
|
"silence_threshold": 0.005, # RMS energy below which audio is considered silence
|
||||||
"notifications": True, # show macOS notifications
|
"notifications": True, # show macOS notifications
|
||||||
"typing_delay": 0.005, # seconds between keystrokes in char mode
|
"typing_delay": 0.005, # seconds between keystrokes in char mode
|
||||||
|
"postprocessing": {
|
||||||
|
"enabled": False,
|
||||||
|
"model": None, # active model HF repo id
|
||||||
|
"models": [], # list of downloaded model repo ids
|
||||||
|
"system_prompt": "Fix grammar and punctuation in the following dictated text. Output only the corrected text, nothing else.",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
LANGUAGES: dict[str, str] = {
|
LANGUAGES: dict[str, str] = {
|
||||||
|
|||||||
72
calliope/postprocessor.py
Normal file
72
calliope/postprocessor.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""LLM post-processing of transcriptions using MLX on Apple Silicon."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)

# Default instruction used when the config provides no system prompt.
DEFAULT_SYSTEM_PROMPT = (
    "You are a speech-to-text post-processor. Your sole job is to clean up "
    "raw transcriptions. Fix punctuation, capitalization, and obvious "
    "mistranscriptions. Do not add, remove, or rephrase any words beyond "
    "what is necessary for correctness. Output ONLY the corrected text with "
    "no commentary, explanations, or prefixes."
)


class Postprocessor:
    """Cleans up raw transcriptions with a local MLX language model.

    The heavyweight ``mlx_lm`` imports are deferred to :meth:`load` and
    :meth:`process` so constructing a Postprocessor stays cheap even when
    the feature is disabled.
    """

    def __init__(self, system_prompt: str = ""):
        """Create an unloaded postprocessor.

        Args:
            system_prompt: Instruction sent to the LLM before each
                transcription; falls back to ``DEFAULT_SYSTEM_PROMPT``
                when empty.
        """
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self._model = None  # MLX model object, None until load()
        self._tokenizer = None  # matching tokenizer, None until load()
        self._model_id: str | None = None  # HF repo id of the loaded model

    def load(self, model_id: str) -> None:
        """Load ``model_id`` (a HuggingFace repo id) into memory.

        No-op when that exact model is already loaded, so callers may invoke
        this freely without paying for a redundant reload.
        """
        if self._model is not None and self._model_id == model_id:
            return
        from mlx_lm import load

        log.info("Loading MLX model %s", model_id)
        self._model, self._tokenizer = load(model_id)
        self._model_id = model_id
        log.info("MLX model ready")

    def process(self, text: str) -> str:
        """Return ``text`` cleaned up by the loaded LLM.

        Raises:
            RuntimeError: if :meth:`load` has not completed successfully.
        """
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Postprocessor model not loaded")

        from mlx_lm import generate

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": text},
        ]
        prompt = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Character count is a crude proxy for tokens; 2x + 100 leaves slack.
        output_budget = len(text) * 2 + 100
        # Allow extra headroom for reasoning/thinking tokens (stripped later).
        reasoning_budget = 2048
        result = generate(
            self._model, self._tokenizer, prompt=prompt, max_tokens=output_budget + reasoning_budget
        )
        # Drop <think>...</think> blocks emitted by reasoning models.
        result = re.sub(r"<think>[\s\S]*?</think>", "", result)
        result = result.strip()
        log.debug("Post-processing input: %s", text)
        log.debug("Post-processing output: %s", result)
        return result

    def unload(self) -> None:
        """Release the model and tokenizer and force a GC pass."""
        import gc

        self._model = None
        self._tokenizer = None
        self._model_id = None
        gc.collect()
        log.info("MLX model unloaded")

    @staticmethod
    def download(hf_repo: str) -> None:
        """Fetch ``hf_repo`` into the local HuggingFace cache."""
        log.info("Downloading MLX model %s", hf_repo)
        snapshot_download(hf_repo)
        log.info("Download complete: %s", hf_repo)
||||||
@@ -99,6 +99,33 @@ def run() -> dict:
|
|||||||
text = transcriber.transcribe(audio)
|
text = transcriber.transcribe(audio)
|
||||||
console.print(f"[green]Result:[/green] {text or '(no speech detected)'}")
|
console.print(f"[green]Result:[/green] {text or '(no speech detected)'}")
|
||||||
|
|
||||||
|
# ── LLM Post-Processing ─────────────────────────────────────────
|
||||||
|
console.print("\n[bold]LLM Post-Processing (optional)[/bold]")
|
||||||
|
console.print(" Clean up grammar & punctuation using a local MLX language model.")
|
||||||
|
if Confirm.ask("Enable LLM post-processing?", default=False):
|
||||||
|
default_llm = "mlx-community/Qwen2.5-0.5B-Instruct-4bit"
|
||||||
|
llm_repo = Prompt.ask("MLX model repo", default=default_llm)
|
||||||
|
console.print(f"Downloading [cyan]{llm_repo}[/cyan]...")
|
||||||
|
|
||||||
|
from calliope.postprocessor import Postprocessor
|
||||||
|
|
||||||
|
with Progress() as progress:
|
||||||
|
task = progress.add_task("Downloading model...", total=None)
|
||||||
|
Postprocessor.download(llm_repo)
|
||||||
|
progress.update(task, completed=100, total=100)
|
||||||
|
|
||||||
|
console.print("[green]Model downloaded.[/green]")
|
||||||
|
|
||||||
|
pp_cfg = cfg.setdefault("postprocessing", {})
|
||||||
|
pp_cfg["enabled"] = True
|
||||||
|
pp_cfg["model"] = llm_repo
|
||||||
|
pp_cfg["models"] = [llm_repo]
|
||||||
|
|
||||||
|
default_prompt = config.DEFAULTS["postprocessing"]["system_prompt"]
|
||||||
|
current_prompt = pp_cfg.get("system_prompt", default_prompt)
|
||||||
|
if not Confirm.ask(f"Use default system prompt?\n \"{current_prompt}\"", default=True):
|
||||||
|
pp_cfg["system_prompt"] = Prompt.ask("System prompt")
|
||||||
|
|
||||||
# ── Save ─────────────────────────────────────────────────────────
|
# ── Save ─────────────────────────────────────────────────────────
|
||||||
config.save(cfg)
|
config.save(cfg)
|
||||||
console.print(f"\n[green]Config saved to {config.CONFIG_PATH}[/green]")
|
console.print(f"\n[green]Config saved to {config.CONFIG_PATH}[/green]")
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ dependencies = [
|
|||||||
"rich>=13.0.0",
|
"rich>=13.0.0",
|
||||||
"click>=8.1.0",
|
"click>=8.1.0",
|
||||||
"pyyaml>=6.0",
|
"pyyaml>=6.0",
|
||||||
|
"mlx>=0.16.0",
|
||||||
|
"mlx-lm>=0.14.0",
|
||||||
|
"huggingface-hub>=0.20.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
Reference in New Issue
Block a user