diff --git a/calliope/app.py b/calliope/app.py index 087441a..38df317 100644 --- a/calliope/app.py +++ b/calliope/app.py @@ -18,6 +18,7 @@ import rumps from calliope import config as config_mod from calliope.recorder import Recorder from calliope.transcriber import Transcriber +from calliope.postprocessor import Postprocessor from calliope.typer import type_text, type_text_clipboard from calliope.hotkeys import HotkeyListener from calliope.overlay import WaveformOverlay @@ -43,6 +44,14 @@ class CalliopeApp(rumps.App): self.transcriber.context = cfg.get("context", "") self.transcriber.language = cfg.get("language", "auto") + # Post-processing + pp_cfg = cfg.get("postprocessing", {}) + self.postprocessor: Postprocessor | None = None + if pp_cfg.get("enabled") and pp_cfg.get("model"): + self.postprocessor = Postprocessor( + system_prompt=pp_cfg.get("system_prompt", ""), + ) + self._recording = False self._rec_lock = threading.Lock() self._rec_start_time: float | None = None @@ -78,6 +87,10 @@ class CalliopeApp(rumps.App): self._mic_menu = rumps.MenuItem("Microphone") self._build_mic_menu() + # Post-Processing submenu + self._pp_menu = rumps.MenuItem("Post-Processing") + self._build_pp_menu() + # Typing mode submenu self._typing_menu = rumps.MenuItem("Typing Mode") current_mode = cfg.get("typing_mode", "char") @@ -98,6 +111,7 @@ class CalliopeApp(rumps.App): self._model_menu, self._mic_menu, self._typing_menu, + self._pp_menu, None, quit_item, ] @@ -138,6 +152,10 @@ class CalliopeApp(rumps.App): self.status_item.title = self._ready_status() self.hotkeys.start() log.info("Model loaded, hotkeys active") + # Load postprocessor if enabled + pp_cfg = self.cfg.get("postprocessing", {}) + if pp_cfg.get("enabled") and pp_cfg.get("model"): + self._ensure_postprocessor(pp_cfg["model"]) except Exception: log.error("Failed to load model", exc_info=True) self.status_item.title = "Status: Model load failed" @@ -353,6 +371,14 @@ class CalliopeApp(rumps.App): self.status_item.title = 
self._ready_status() self._notify("Calliope", "", "No speech detected — audio too short or too quiet") return + # LLM post-processing + pp_cfg = self.cfg.get("postprocessing", {}) + if pp_cfg.get("enabled") and self.postprocessor and self.postprocessor._model is not None: + try: + self.status_item.title = "Status: Post-processing..." + text = self.postprocessor.process(text) + except Exception: + log.error("Post-processing failed, using raw transcription", exc_info=True) if text: def _do_type(): try: @@ -377,7 +403,179 @@ class CalliopeApp(rumps.App): self.title = "\U0001f3a4" # 🎤 self._transcribe_done.set() + # ── Post-Processing ─────────────────────────────────────────── + + def _build_pp_menu(self) -> None: + if self._pp_menu._menu is not None: + self._pp_menu.clear() + pp_cfg = self.cfg.get("postprocessing", {}) + enabled = pp_cfg.get("enabled", False) + active_model = pp_cfg.get("model") + models = pp_cfg.get("models", []) + + # Enable/disable toggle + toggle_label = "Disable Post-Processing" if enabled else "Enable Post-Processing" + self._pp_menu.add(rumps.MenuItem(toggle_label, callback=self._on_pp_toggle)) + self._pp_menu.add(None) # separator + + # Downloaded models + if models: + for m in models: + short = m.split("/")[-1] + prefix = "\u2713 " if m == active_model else " " + item = rumps.MenuItem(f"{prefix}{short}", callback=self._on_pp_model_select) + item._pp_model_id = m + self._pp_menu.add(item) + self._pp_menu.add(None) + + self._pp_menu.add(rumps.MenuItem("Download Model...", callback=self._on_pp_download)) + self._pp_menu.add(rumps.MenuItem("Edit System Prompt...", callback=self._on_pp_edit_prompt)) + if models: + self._pp_menu.add(rumps.MenuItem("Delete Model...", callback=self._on_pp_delete)) + + def _on_pp_toggle(self, sender) -> None: + pp_cfg = self.cfg.setdefault("postprocessing", {}) + enabled = not pp_cfg.get("enabled", False) + pp_cfg["enabled"] = enabled + config_mod.save(self.cfg) + if enabled and pp_cfg.get("model"): + 
self._ensure_postprocessor(pp_cfg["model"]) + elif not enabled: + self._release_postprocessor() + self._build_pp_menu() + log.info("Post-processing %s", "enabled" if enabled else "disabled") + + def _on_pp_model_select(self, sender) -> None: + model_id = sender._pp_model_id + pp_cfg = self.cfg.setdefault("postprocessing", {}) + if model_id == pp_cfg.get("model"): + return + pp_cfg["model"] = model_id + config_mod.save(self.cfg) + if pp_cfg.get("enabled"): + self._ensure_postprocessor(model_id) + self._build_pp_menu() + log.info("Post-processing model set to %s", model_id) + + def _on_pp_download(self, sender) -> None: + self._activate_app() + response = rumps.Window( + message="Enter a HuggingFace MLX model repo ID.\n\n" + "Example: mlx-community/Qwen2.5-0.5B-Instruct-4bit", + title="Download MLX Model", + default_text="mlx-community/Qwen2.5-0.5B-Instruct-4bit", + ok="Download", + cancel="Cancel", + dimensions=(320, 24), + ).run() + self._deactivate_app() + if response.clicked != 1: + return + repo = response.text.strip() + if not repo: + return + self._notify("Calliope", "", f"Downloading {repo}...") + + def _do_download(): + try: + import huggingface_hub.constants as hf_constants + os.environ["HF_HUB_OFFLINE"] = "0" + hf_constants.HF_HUB_OFFLINE = False + Postprocessor.download(repo) + pp_cfg = self.cfg.setdefault("postprocessing", {}) + if repo not in pp_cfg.setdefault("models", []): + pp_cfg["models"].append(repo) + if not pp_cfg.get("model"): + pp_cfg["model"] = repo + config_mod.save(self.cfg) + self._build_pp_menu() + self._notify("Calliope", "", f"Model downloaded: {repo}") + except Exception: + log.error("Failed to download %s", repo, exc_info=True) + self._notify("Calliope", "Error", f"Failed to download {repo}") + finally: + os.environ["HF_HUB_OFFLINE"] = "1" + hf_constants.HF_HUB_OFFLINE = True + + threading.Thread(target=_do_download, daemon=True).start() + + def _on_pp_edit_prompt(self, sender) -> None: + pp_cfg = self.cfg.setdefault("postprocessing", 
{}) + current = pp_cfg.get("system_prompt", "") + self._activate_app() + response = rumps.Window( + message="System prompt sent to the LLM before your transcription:", + title="Edit System Prompt", + default_text=current, + ok="Save", + cancel="Cancel", + dimensions=(320, 120), + ).run() + self._deactivate_app() + if response.clicked != 1: + return + pp_cfg["system_prompt"] = response.text.strip() + config_mod.save(self.cfg) + if self.postprocessor: + from calliope.postprocessor import DEFAULT_SYSTEM_PROMPT + self.postprocessor.system_prompt = pp_cfg["system_prompt"] or DEFAULT_SYSTEM_PROMPT + log.info("Post-processing system prompt updated") + + def _on_pp_delete(self, sender) -> None: + pp_cfg = self.cfg.setdefault("postprocessing", {}) + models = pp_cfg.get("models", []) + if not models: + return + self._activate_app() + response = rumps.Window( + message="Enter the repo ID of the model to remove from Calliope:\n\n" + + "\n".join(f" • {m}" for m in models), + title="Delete Model", + default_text="", + ok="Delete", + cancel="Cancel", + dimensions=(320, 24), + ).run() + self._deactivate_app() + if response.clicked != 1: + return + repo = response.text.strip() + if repo not in models: + return + models.remove(repo) + if pp_cfg.get("model") == repo: + pp_cfg["model"] = models[0] if models else None + if not models: + pp_cfg["enabled"] = False + self._release_postprocessor() + config_mod.save(self.cfg) + self._build_pp_menu() + log.info("Removed model %s", repo) + + def _ensure_postprocessor(self, model_id: str) -> None: + """Load the postprocessor model in a background thread.""" + def _load(): + try: + if self.postprocessor is None: + pp_cfg = self.cfg.get("postprocessing", {}) + self.postprocessor = Postprocessor( + system_prompt=pp_cfg.get("system_prompt", ""), + ) + self.postprocessor.unload() + self.postprocessor.load(model_id) + except Exception: + log.error("Failed to load postprocessor %s", model_id, exc_info=True) + self._notify("Calliope", "Error", 
f"Failed to load LLM: {model_id}") + + threading.Thread(target=_load, daemon=True).start() + + def _release_postprocessor(self) -> None: + if self.postprocessor is not None: + self.postprocessor.unload() + self.postprocessor = None + def _on_quit(self, sender) -> None: + self._release_postprocessor() self.hotkeys.stop() self.recorder.stop() # Stop overlay timers synchronously to avoid retain cycles on quit. diff --git a/calliope/config.py b/calliope/config.py index 4d578ed..6ffa6f3 100644 --- a/calliope/config.py +++ b/calliope/config.py @@ -26,6 +26,12 @@ DEFAULTS: dict[str, Any] = { "silence_threshold": 0.005, # RMS energy below which audio is considered silence "notifications": True, # show macOS notifications "typing_delay": 0.005, # seconds between keystrokes in char mode + "postprocessing": { + "enabled": False, + "model": None, # active model HF repo id + "models": [], # list of downloaded model repo ids + "system_prompt": "Fix grammar and punctuation in the following dictated text. Output only the corrected text, nothing else.", + }, } LANGUAGES: dict[str, str] = { diff --git a/calliope/postprocessor.py b/calliope/postprocessor.py new file mode 100644 index 0000000..7f1fa72 --- /dev/null +++ b/calliope/postprocessor.py @@ -0,0 +1,72 @@ +"""LLM post-processing of transcriptions using MLX on Apple Silicon.""" + +import logging +import re + +from huggingface_hub import snapshot_download + +log = logging.getLogger(__name__) + +DEFAULT_SYSTEM_PROMPT = ( + "You are a speech-to-text post-processor. Your sole job is to clean up " + "raw transcriptions. Fix punctuation, capitalization, and obvious " + "mistranscriptions. Do not add, remove, or rephrase any words beyond " + "what is necessary for correctness. Output ONLY the corrected text with " + "no commentary, explanations, or prefixes." 
+) + + +class Postprocessor: + def __init__(self, system_prompt: str = ""): + self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT + self._model = None + self._tokenizer = None + self._model_id: str | None = None + + def load(self, model_id: str) -> None: + from mlx_lm import load + + log.info("Loading MLX model %s", model_id) + self._model, self._tokenizer = load(model_id) + self._model_id = model_id + log.info("MLX model ready") + + def process(self, text: str) -> str: + if self._model is None or self._tokenizer is None: + raise RuntimeError("Postprocessor model not loaded") + + from mlx_lm import generate + + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": text}, + ] + prompt = self._tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + output_budget = len(text) * 2 + 100 + # Allow extra headroom for reasoning/thinking tokens (stripped later). + reasoning_budget = 2048 + result = generate( + self._model, self._tokenizer, prompt=prompt, max_tokens=output_budget + reasoning_budget + ) + result = re.sub(r"<think>[\s\S]*?</think>", "", result) + result = result.strip() + log.debug("Post-processing input: %s", text) + log.debug("Post-processing output: %s", result) + return result + + def unload(self) -> None: + import gc + + self._model = None + self._tokenizer = None + self._model_id = None + gc.collect() + log.info("MLX model unloaded") + + @staticmethod + def download(hf_repo: str) -> None: + log.info("Downloading MLX model %s", hf_repo) + snapshot_download(hf_repo) + log.info("Download complete: %s", hf_repo) diff --git a/calliope/setup_wizard.py b/calliope/setup_wizard.py index 186ab6a..ae61fd1 100644 --- a/calliope/setup_wizard.py +++ b/calliope/setup_wizard.py @@ -99,6 +99,33 @@ def run() -> dict: text = transcriber.transcribe(audio) console.print(f"[green]Result:[/green] {text or '(no speech detected)'}") + + # ── LLM Post-Processing ───────────────────────────────────────── + 
console.print("\n[bold]LLM Post-Processing (optional)[/bold]") + console.print(" Clean up grammar & punctuation using a local MLX language model.") + if Confirm.ask("Enable LLM post-processing?", default=False): + default_llm = "mlx-community/Qwen2.5-0.5B-Instruct-4bit" + llm_repo = Prompt.ask("MLX model repo", default=default_llm) + console.print(f"Downloading [cyan]{llm_repo}[/cyan]...") + + from calliope.postprocessor import Postprocessor + + with Progress() as progress: + task = progress.add_task("Downloading model...", total=None) + Postprocessor.download(llm_repo) + progress.update(task, completed=100, total=100) + + console.print("[green]Model downloaded.[/green]") + + pp_cfg = cfg.setdefault("postprocessing", {}) + pp_cfg["enabled"] = True + pp_cfg["model"] = llm_repo + pp_cfg["models"] = [llm_repo] + + default_prompt = config.DEFAULTS["postprocessing"]["system_prompt"] + current_prompt = pp_cfg.get("system_prompt", default_prompt) + if not Confirm.ask(f"Use default system prompt?\n \"{current_prompt}\"", default=True): + pp_cfg["system_prompt"] = Prompt.ask("System prompt") + # ── Save ───────────────────────────────────────────────────────── config.save(cfg) console.print(f"\n[green]Config saved to {config.CONFIG_PATH}[/green]") diff --git a/pyproject.toml b/pyproject.toml index 08a8760..36b1afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,9 @@ dependencies = [ "rich>=13.0.0", "click>=8.1.0", "pyyaml>=6.0", + "mlx>=0.16.0", + "mlx-lm>=0.14.0", + "huggingface-hub>=0.20.0", ] [project.scripts]