diff --git a/.gitignore b/.gitignore index fe89e8d..af83629 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ venv .vscode .pytest_cache/ __pycache__/ +faiss + diff --git a/bertalign/aligner.py b/bertalign/aligner.py index a8ffed9..516b1f3 100644 --- a/bertalign/aligner.py +++ b/bertalign/aligner.py @@ -2,48 +2,86 @@ import numpy as np from bertalign import model from bertalign.corelib import * +from bertalign.chunk import ( +    AlignmentResult, +    AlignmentError, +    TextChunk, +    get_paragraph_indices, +    validate_alignment, +    process_alignments, +) from bertalign.utils import * + class Bertalign: -    def __init__(self, -                 src, -                 tgt, -                 max_align=5, -                 top_k=3, -                 win=5, -                 skip=-0.1, -                 margin=True, -                 len_penalty=True, -                 is_split=False, -                 src_lang=None, -                 tgt_lang=None -               ): -         +    """An automatic multilingual sentence aligner. + +    Uses sentence-transformers to align sentences across languages through semantic similarity. +    Performs two-step alignment: first finding 1-1 anchor points, then extracting more complex +    alignments (1-many, many-1, many-many) within constrained search paths.""" + +    def __init__( +        self, +        src, +        tgt, +        max_align=5, +        top_k=3, +        win=5, +        skip=-0.1, +        margin=True, +        len_penalty=True, +        is_split=False, +        src_lang=None, +        tgt_lang=None, +    ): +        """Initialize aligner with source and target texts. 
+ + Args: + src (str | list[str]): Source text or list of sentences + tgt (str | list[str]): Target text or list of sentences + max_align (int, optional): Max sentences per alignment (default: 5) + top_k (int, optional): Candidate targets per source (default: 3) + win (int, optional): Window size for searching in second pass (default: 5) + skip (float, optional): Skip penalty (default: -0.1) + margin (bool, optional): Use margin in scoring (default: True) + len_penalty (bool, optional): Apply length penalty (default: True) + is_split (bool, optional): Input is pre-split into lines (default: False) + src_lang (str, optional): Source language code (auto-detected if None) + tgt_lang (str, optional): Target language code (auto-detected if None) + """ + self.max_align = max_align self.top_k = top_k self.win = win self.skip = skip self.margin = margin self.len_penalty = len_penalty - - src = clean_text(src) - tgt = clean_text(tgt) + src_lang = src_lang if src_lang is not None else detect_lang(src) tgt_lang = tgt_lang if tgt_lang is not None else detect_lang(tgt) - - if is_split: + + # src = clean_text(src) + # tgt = clean_text(tgt) + + if isinstance(src, list) and isinstance(tgt, list): + src_sents = src + tgt_sents = tgt + elif is_split: src_sents = src.splitlines() tgt_sents = tgt.splitlines() else: src_sents = split_sents(src, src_lang) tgt_sents = split_sents(tgt, tgt_lang) - + + src_sents = clean_text(src_sents) + tgt_sents = clean_text(tgt_sents) + src_num = len(src_sents) tgt_num = len(tgt_sents) - + src_lang = LANG.ISO[src_lang] tgt_lang = LANG.ISO[tgt_lang] - + print("Source language: {}, Number of sentences: {}".format(src_lang, src_num)) print("Target language: {}, Number of sentences: {}".format(tgt_lang, tgt_num)) @@ -53,6 +91,8 @@ class Bertalign: char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,]) + self.src = src + self.tgt = tgt self.src_lang = src_lang self.tgt_lang = tgt_lang self.src_sents = src_sents @@ -64,36 +104,108 @@ class Bertalign: 
self.char_ratio = char_ratio self.src_vecs = src_vecs self.tgt_vecs = tgt_vecs - + def align_sents(self): + """Execute two-pass sentence alignment + + 1. First pass: finds 1-1 alignments as anchors + 2. Second pass: finds complex n:m alignments between anchors + """ print("Performing first-step alignment ...") - D, I = find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k) - first_alignment_types = get_alignment_types(2) # 0-1, 1-0, 1-1 + D, I = find_top_k_sents(self.src_vecs[0, :], self.tgt_vecs[0, :], k=self.top_k) + first_alignment_types = get_alignment_types(2) # 0-1, 1-0, 1-1 first_w, first_path = find_first_search_path(self.src_num, self.tgt_num) - first_pointers = first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I) - first_alignment = first_back_track(self.src_num, self.tgt_num, first_pointers, first_path, first_alignment_types) - + first_pointers = first_pass_align( + self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I + ) + first_alignment = first_back_track( + self.src_num, + self.tgt_num, + first_pointers, + first_path, + first_alignment_types, + ) + print("Performing second-step alignment ...") second_alignment_types = get_alignment_types(self.max_align) - second_w, second_path = find_second_search_path(first_alignment, self.win, self.src_num, self.tgt_num) - second_pointers = second_pass_align(self.src_vecs, self.tgt_vecs, self.src_lens, self.tgt_lens, - second_w, second_path, second_alignment_types, - self.char_ratio, self.skip, margin=self.margin, len_penalty=self.len_penalty) - second_alignment = second_back_track(self.src_num, self.tgt_num, second_pointers, second_path, second_alignment_types) - - print("Finished! 
Successfully aligning {} {} sentences to {} {} sentences\n".format(self.src_num, self.src_lang, self.tgt_num, self.tgt_lang)) + second_w, second_path = find_second_search_path( + first_alignment, self.win, self.src_num, self.tgt_num + ) + second_pointers = second_pass_align( + self.src_vecs, + self.tgt_vecs, + self.src_lens, + self.tgt_lens, + second_w, + second_path, + second_alignment_types, + self.char_ratio, + self.skip, + margin=self.margin, + len_penalty=self.len_penalty, + ) + second_alignment = second_back_track( + self.src_num, + self.tgt_num, + second_pointers, + second_path, + second_alignment_types, + ) + + print( + "Finished! Successfully aligning {} {} sentences to {} {} sentences\n".format( + self.src_num, self.src_lang, self.tgt_num, self.tgt_lang + ) + ) self.result = second_alignment - + def print_sents(self): - for bead in (self.result): + """Print aligned sentences in parallel""" + for bead in self.result: src_line = self._get_line(bead[0], self.src_sents) tgt_line = self._get_line(bead[1], self.tgt_sents) print(src_line + "\n" + tgt_line + "\n") + def chunk(self, max_chars: int = 512) -> list[TextChunk]: + """Create aligned chunks respecting paragraph and sentence boundaries + + Args: + max_chars (int, optional): Maximum characters per chunk. Defaults to 512. 
+ + Returns: + list[TextChunk]: List of TextChunk with aligned paragraphs + """ + source_paragraphs = self.src_sents + target_paragraphs = self.tgt_sents + + self.align_sents() + + if not isinstance(self.result, list): + raise AlignmentError("Invalid alignment result type") + + alignments: AlignmentResult = self.result + if not alignments: + raise AlignmentError("Aligner produced no results") + + # get paragraph mappings and validate + source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs) + target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs) + + validate_alignment(self.result, source_sent_count, target_sent_count) + + return process_alignments( + self.result, + source_paragraphs, + target_paragraphs, + source_par_map, + target_par_map, + max_chars, + ) + @staticmethod def _get_line(bead, lines): - line = '' + line = "" if len(bead) > 0: - line = ' '.join(lines[bead[0]:bead[-1]+1]) + line = " ".join(lines[bead[0] : bead[-1] + 1]) return line diff --git a/bertalign/chunk.py b/bertalign/chunk.py index d5fe854..530f2e8 100644 --- a/bertalign/chunk.py +++ b/bertalign/chunk.py @@ -1,4 +1,3 @@ -from bertalign.aligner import Bertalign from dataclasses import dataclass from typing import TypeAlias @@ -158,52 +157,3 @@ def should_create_new_chunk( return at_paragraph_boundary( source_indices, source_par_map ) and at_paragraph_boundary(target_indices, target_par_map) - - -def create_aligned_chunks( - source_text: str, target_text: str, max_chars: int = 500 -) -> list[TextChunk]: - """Create aligned chunks of text respecting paragraph and sentence alignment. 
- - Args: - source_text: The source text to be chunked - target_text: The target text to be chunked - max_chars: Maximum characters per chunk (default: 500) - - Returns: - List of TextChunk objects containing aligned paragraphs - """ - source_paragraphs = split_into_paragraphs(source_text) - target_paragraphs = split_into_paragraphs(target_text) - - # TODO: just accept Bertalign instance as arg - aligner = Bertalign( - "\n".join(source_paragraphs), - "\n".join(target_paragraphs), - is_split=True, - src_lang="zh", - tgt_lang="en", - ) - aligner.align_sents() - - if not isinstance(aligner.result, list): - raise AlignmentError("Invalid alignment result type") - - alignments: AlignmentResult = aligner.result - if not alignments: - raise AlignmentError("Aligner produced no results") - - # get para mappings and validate - source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs) - target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs) - - validate_alignment(aligner.result, source_sent_count, target_sent_count) - - return process_alignments( - aligner.result, - source_paragraphs, - target_paragraphs, - source_par_map, - target_par_map, - max_chars, - ) diff --git a/bertalign/corelib.py b/bertalign/corelib.py index 00f1539..a1baa14 100644 --- a/bertalign/corelib.py +++ b/bertalign/corelib.py @@ -389,7 +389,7 @@ def find_top_k_sents(src_vecs, tgt_vecs, k=3): I: numpy array. Target index matrix of shape (num_src_sents, k). 
""" embedding_size = src_vecs.shape[1] - if torch.cuda.is_available() and platform == 'linux': # GPU version + if faiss.get_num_gpus() > 0 and platform == 'linux': # GPU version res = faiss.StandardGpuResources() index = faiss.IndexFlatIP(embedding_size) gpu_index = faiss.index_cpu_to_gpu(res, 0, index) diff --git a/bertalign/utils.py b/bertalign/utils.py index dc89a26..b085768 100644 --- a/bertalign/utils.py +++ b/bertalign/utils.py @@ -1,17 +1,25 @@ +from typing import Any + + import re from googletrans import Translator from sentence_splitter import SentenceSplitter -def clean_text(text): - clean_text = [] +def clean_line(line: str) -> str: + """Clean a single line of text.""" + line = line.strip() + if line: + line = re.sub(r'\s+', ' ', line) + return line + +def clean_text(text: str | list[str]) -> str | list[str]: + """Clean text or list of strings.""" + if isinstance(text, list): + return [clean_line(line) for line in text if clean_line(line)] + text = text.strip() lines = text.splitlines() - for line in lines: - line = line.strip() - if line: - line = re.sub('\s+', ' ', line) - clean_text.append(line) - return "\n".join(clean_text) + return [clean_line(line) for line in lines if clean_line(line)] def detect_lang(text): translator = Translator(service_urls=[ @@ -37,7 +45,7 @@ def split_sents(text, lang): raise Exception('The language {} is not suppored yet.'.format(LANG.ISO[lang])) def _split_zh(text, limit=1000): - sent_list = [] + sent_list: list[str] = [] text = re.sub('(?P([。?!](?![”’"\')])))', r'\g\n', text) text = re.sub('(?P([。?!]|…{1,2})[”’"\')])', r'\g\n', text) diff --git a/tests/chunk_test.py b/tests/chunk_test.py index 45eb9ed..d67fdd3 100644 --- a/tests/chunk_test.py +++ b/tests/chunk_test.py @@ -1,7 +1,18 @@ -from typing_extensions import override import unittest import os -from bertalign.chunk import create_aligned_chunks, TextChunk, AlignmentError +from typing_extensions import override +from bertalign.aligner import Bertalign +from 
bertalign.chunk import TextChunk, split_into_paragraphs + + +def print_chunks(chunks: list[TextChunk]) -> None: + for i, chunk in enumerate(chunks, 1): + print(f"\nChunk {i}:") + print(f"Source text ({len(chunk.source_text)} chars):") + print(chunk.source_text) + print(f"\nTarget text ({len(chunk.target_text)} chars):") + print(chunk.target_text) + print("-" * 80) class TestChunk(unittest.TestCase): @@ -28,22 +39,27 @@ class TestChunk(unittest.TestCase): self.target_text = f.read() def test_create_aligned_chunks(self): - chunks = create_aligned_chunks( - self.source_text, self.target_text, max_chars=500 - ) + src_ps = split_into_paragraphs(self.source_text) + tgt_ps = split_into_paragraphs(self.target_text) + aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en") + chunks = aligner.chunk(300) self.assertIsInstance(chunks, list) - self.assertTrue(all(isinstance(chunk, TextChunk) for chunk in chunks)) for chunk in chunks: self.assertIsInstance(chunk.source_text, str) self.assertIsInstance(chunk.target_text, str) + print_chunks(chunks) + self.assertGreater(len(chunks), 0) @unittest.skip("no") def test_create_aligned_chunks_empty_input(self): - c = create_aligned_chunks("", self.target_text) + src_ps = split_into_paragraphs(self.source_text) + tgt_ps = split_into_paragraphs(self.target_text) + aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en") + c = aligner.chunk() self.assertIsNone(c)