fix: faiss stuff, api

2025-02-18 23:05:35 +06:00
parent 5d4149be63
commit 875e189aab
6 changed files with 194 additions and 106 deletions

.gitignore

@@ -2,3 +2,5 @@ venv
 .vscode
 .pytest_cache/
 __pycache__/
+faiss

bertalign/aligner.py

@@ -2,48 +2,86 @@ import numpy as np
 from bertalign import model
 from bertalign.corelib import *
+from bertalign.chunk import (
+    AlignmentResult,
+    AlignmentError,
+    TextChunk,
+    get_paragraph_indices,
+    validate_alignment,
+    process_alignments,
+)
 from bertalign.utils import *
 
 class Bertalign:
-    def __init__(self,
-                 src,
-                 tgt,
-                 max_align=5,
-                 top_k=3,
-                 win=5,
-                 skip=-0.1,
-                 margin=True,
-                 len_penalty=True,
-                 is_split=False,
-                 src_lang=None,
-                 tgt_lang=None
-                 ):
+    """An automatic multilingual sentence aligner.
+
+    Uses sentence-transformers to align sentences across languages through
+    semantic similarity. Performs two-step alignment: first finding 1-1 anchor
+    points, then extracting more complex alignments (1-many, many-1, many-many)
+    within constrained search paths.
+    """
+
+    def __init__(
+        self,
+        src,
+        tgt,
+        max_align=5,
+        top_k=3,
+        win=5,
+        skip=-0.1,
+        margin=True,
+        len_penalty=True,
+        is_split=False,
+        src_lang=None,
+        tgt_lang=None,
+    ):
+        """Initialize aligner with source and target texts.
+
+        Args:
+            src (str | list[str]): Source text or list of sentences
+            tgt (str | list[str]): Target text or list of sentences
+            max_align (int, optional): Max sentences per alignment (default: 5)
+            top_k (int, optional): Candidate targets per source (default: 3)
+            win (int, optional): Window size for searching in second pass (default: 5)
+            skip (float, optional): Skip penalty (default: -0.1)
+            margin (bool, optional): Use margin in scoring (default: True)
+            len_penalty (bool, optional): Apply length penalty (default: True)
+            is_split (bool, optional): Input is pre-split into lines (default: False)
+            src_lang (str, optional): Source language code (auto-detected if None)
+            tgt_lang (str, optional): Target language code (auto-detected if None)
+        """
         self.max_align = max_align
         self.top_k = top_k
         self.win = win
         self.skip = skip
         self.margin = margin
         self.len_penalty = len_penalty
 
-        src = clean_text(src)
-        tgt = clean_text(tgt)
         src_lang = src_lang if src_lang is not None else detect_lang(src)
         tgt_lang = tgt_lang if tgt_lang is not None else detect_lang(tgt)
 
-        if is_split:
+        # src = clean_text(src)
+        # tgt = clean_text(tgt)
+        if isinstance(src, list) and isinstance(tgt, list):
+            src_sents = src
+            tgt_sents = tgt
+        elif is_split:
             src_sents = src.splitlines()
             tgt_sents = tgt.splitlines()
         else:
             src_sents = split_sents(src, src_lang)
             tgt_sents = split_sents(tgt, tgt_lang)
+        src_sents = clean_text(src_sents)
+        tgt_sents = clean_text(tgt_sents)
 
         src_num = len(src_sents)
         tgt_num = len(tgt_sents)
 
         src_lang = LANG.ISO[src_lang]
         tgt_lang = LANG.ISO[tgt_lang]
         print("Source language: {}, Number of sentences: {}".format(src_lang, src_num))
         print("Target language: {}, Number of sentences: {}".format(tgt_lang, tgt_num))
@@ -53,6 +91,8 @@ class Bertalign:
         char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
 
+        self.src = src
+        self.tgt = tgt
         self.src_lang = src_lang
         self.tgt_lang = tgt_lang
         self.src_sents = src_sents
@@ -64,36 +104,108 @@ class Bertalign:
         self.char_ratio = char_ratio
         self.src_vecs = src_vecs
         self.tgt_vecs = tgt_vecs
 
     def align_sents(self):
+        """Execute two-pass sentence alignment.
+
+        1. First pass: finds 1-1 alignments as anchors
+        2. Second pass: finds complex n:m alignments between anchors
+        """
         print("Performing first-step alignment ...")
-        D, I = find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k)
+        D, I = find_top_k_sents(self.src_vecs[0, :], self.tgt_vecs[0, :], k=self.top_k)
         first_alignment_types = get_alignment_types(2) # 0-1, 1-0, 1-1
         first_w, first_path = find_first_search_path(self.src_num, self.tgt_num)
-        first_pointers = first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I)
-        first_alignment = first_back_track(self.src_num, self.tgt_num, first_pointers, first_path, first_alignment_types)
+        first_pointers = first_pass_align(
+            self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I
+        )
+        first_alignment = first_back_track(
+            self.src_num,
+            self.tgt_num,
+            first_pointers,
+            first_path,
+            first_alignment_types,
+        )
 
         print("Performing second-step alignment ...")
         second_alignment_types = get_alignment_types(self.max_align)
-        second_w, second_path = find_second_search_path(first_alignment, self.win, self.src_num, self.tgt_num)
-        second_pointers = second_pass_align(self.src_vecs, self.tgt_vecs, self.src_lens, self.tgt_lens,
-                                            second_w, second_path, second_alignment_types,
-                                            self.char_ratio, self.skip, margin=self.margin, len_penalty=self.len_penalty)
-        second_alignment = second_back_track(self.src_num, self.tgt_num, second_pointers, second_path, second_alignment_types)
-
-        print("Finished! Successfully aligning {} {} sentences to {} {} sentences\n".format(self.src_num, self.src_lang, self.tgt_num, self.tgt_lang))
+        second_w, second_path = find_second_search_path(
+            first_alignment, self.win, self.src_num, self.tgt_num
+        )
+        second_pointers = second_pass_align(
+            self.src_vecs,
+            self.tgt_vecs,
+            self.src_lens,
+            self.tgt_lens,
+            second_w,
+            second_path,
+            second_alignment_types,
+            self.char_ratio,
+            self.skip,
+            margin=self.margin,
+            len_penalty=self.len_penalty,
+        )
+        second_alignment = second_back_track(
+            self.src_num,
+            self.tgt_num,
+            second_pointers,
+            second_path,
+            second_alignment_types,
+        )
+
+        print(
+            "Finished! Successfully aligning {} {} sentences to {} {} sentences\n".format(
+                self.src_num, self.src_lang, self.tgt_num, self.tgt_lang
+            )
+        )
         self.result = second_alignment
 
     def print_sents(self):
-        for bead in (self.result):
+        """Print aligned sentences in parallel"""
+        for bead in self.result:
             src_line = self._get_line(bead[0], self.src_sents)
             tgt_line = self._get_line(bead[1], self.tgt_sents)
             print(src_line + "\n" + tgt_line + "\n")
 
+    def chunk(self, max_chars: int = 512) -> list[TextChunk]:
+        """Create aligned chunks respecting paragraph and sentence boundaries.
+
+        Args:
+            max_chars (int, optional): Maximum characters per chunk. Defaults to 512.
+
+        Returns:
+            list[TextChunk]: List of TextChunk with aligned paragraphs
+        """
+        source_paragraphs = self.src_sents
+        target_paragraphs = self.tgt_sents
+
+        self.align_sents()
+        if not isinstance(self.result, list):
+            raise AlignmentError("Invalid alignment result type")
+        alignments: AlignmentResult = self.result
+        if not alignments:
+            raise AlignmentError("Aligner produced no results")
+
+        # get paragraph mappings and validate
+        source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs)
+        target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs)
+        validate_alignment(self.result, source_sent_count, target_sent_count)
+
+        return process_alignments(
+            self.result,
+            source_paragraphs,
+            target_paragraphs,
+            source_par_map,
+            target_par_map,
+            max_chars,
+        )
+
     @staticmethod
     def _get_line(bead, lines):
-        line = ''
+        line = ""
         if len(bead) > 0:
-            line = ' '.join(lines[bead[0]:bead[-1]+1])
+            line = " ".join(lines[bead[0] : bead[-1] + 1])
         return line
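
With these changes the constructor accepts pre-split sentence lists (bypassing sentence splitting entirely), cleaning runs after splitting instead of before, and chunking is exposed as a method on the aligner. A minimal usage sketch of the new API; the sample sentences are invented for illustration:

```python
# Minimal usage sketch of the reworked API (sample sentences are invented).
from bertalign.aligner import Bertalign

src_sents = ["今天天气很好。", "我们去公园散步吧。"]  # pre-split source sentences
tgt_sents = ["The weather is nice today.", "Let's take a walk in the park."]

# List inputs skip split_sents(); clean_text() now runs on the sentence lists.
aligner = Bertalign(src_sents, tgt_sents, src_lang="zh", tgt_lang="en")

# chunk() calls align_sents() internally, validates the result, and groups
# aligned beads into TextChunk objects of at most max_chars characters.
chunks = aligner.chunk(max_chars=300)
for c in chunks:
    print(c.source_text)
    print(c.target_text)
```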

bertalign/chunk.py

@@ -1,4 +1,3 @@
-from bertalign.aligner import Bertalign
 from dataclasses import dataclass
 from typing import TypeAlias
@@ -158,52 +157,3 @@ def should_create_new_chunk(
     return at_paragraph_boundary(
         source_indices, source_par_map
     ) and at_paragraph_boundary(target_indices, target_par_map)
-
-def create_aligned_chunks(
-    source_text: str, target_text: str, max_chars: int = 500
-) -> list[TextChunk]:
-    """Create aligned chunks of text respecting paragraph and sentence alignment.
-
-    Args:
-        source_text: The source text to be chunked
-        target_text: The target text to be chunked
-        max_chars: Maximum characters per chunk (default: 500)
-
-    Returns:
-        List of TextChunk objects containing aligned paragraphs
-    """
-    source_paragraphs = split_into_paragraphs(source_text)
-    target_paragraphs = split_into_paragraphs(target_text)
-
-    # TODO: just accept Bertalign instance as arg
-    aligner = Bertalign(
-        "\n".join(source_paragraphs),
-        "\n".join(target_paragraphs),
-        is_split=True,
-        src_lang="zh",
-        tgt_lang="en",
-    )
-    aligner.align_sents()
-
-    if not isinstance(aligner.result, list):
-        raise AlignmentError("Invalid alignment result type")
-    alignments: AlignmentResult = aligner.result
-    if not alignments:
-        raise AlignmentError("Aligner produced no results")
-
-    # get para mappings and validate
-    source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs)
-    target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs)
-    validate_alignment(aligner.result, source_sent_count, target_sent_count)
-
-    return process_alignments(
-        aligner.result,
-        source_paragraphs,
-        target_paragraphs,
-        source_par_map,
-        target_par_map,
-        max_chars,
-    )
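
Removing `create_aligned_chunks` in favor of `Bertalign.chunk()` also breaks the module import cycle (note the dropped `from bertalign.aligner import Bertalign` above). A migration sketch reproducing the old behavior through the new API, including the old function's hard-coded zh/en language pair:

```python
# Migration sketch: the removed create_aligned_chunks(source_text, target_text,
# max_chars=500) expressed through the new Bertalign.chunk() API.
from bertalign.aligner import Bertalign
from bertalign.chunk import TextChunk, split_into_paragraphs


def create_aligned_chunks(
    source_text: str, target_text: str, max_chars: int = 500
) -> list[TextChunk]:
    src_ps = split_into_paragraphs(source_text)
    tgt_ps = split_into_paragraphs(target_text)
    # The old function hard-coded the language pair; kept here for parity.
    aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en")
    return aligner.chunk(max_chars)
```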

bertalign/corelib.py

@@ -389,7 +389,7 @@ def find_top_k_sents(src_vecs, tgt_vecs, k=3):
         I: numpy array. Target index matrix of shape (num_src_sents, k).
     """
     embedding_size = src_vecs.shape[1]
-    if torch.cuda.is_available() and platform == 'linux': # GPU version
+    if faiss.get_num_gpus() > 0 and platform == 'linux': # GPU version
         res = faiss.StandardGpuResources()
         index = faiss.IndexFlatIP(embedding_size)
         gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
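
The GPU gate changes from `torch.cuda.is_available()` to `faiss.get_num_gpus()`: torch can report CUDA on a machine where only faiss-cpu is installed, in which case `faiss.StandardGpuResources` does not exist and the old branch crashed. A standalone sketch of the same guard (the helper name is illustrative, not part of the codebase):

```python
import faiss
import numpy as np


def build_ip_index(vecs: np.ndarray) -> faiss.Index:
    """Illustrative helper: inner-product index, moved to GPU when faiss has one.

    faiss.get_num_gpus() returns 0 on faiss-cpu builds, so this gate can never
    select a GPU branch that the installed faiss cannot execute.
    """
    index = faiss.IndexFlatIP(vecs.shape[1])  # cosine similarity on normalized vecs
    if faiss.get_num_gpus() > 0:
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)  # device 0
    index.add(vecs.astype(np.float32))
    return index
```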

bertalign/utils.py

@@ -1,17 +1,25 @@
+from typing import Any
 import re
 
 from googletrans import Translator
 from sentence_splitter import SentenceSplitter
 
-def clean_text(text):
-    clean_text = []
+def clean_line(line: str) -> str:
+    """Clean a single line of text."""
+    line = line.strip()
+    if line:
+        line = re.sub(r'\s+', ' ', line)
+    return line
+
+def clean_text(text: str | list[str]) -> str | list[str]:
+    """Clean text or list of strings."""
+    if isinstance(text, list):
+        return [clean_line(line) for line in text if clean_line(line)]
     text = text.strip()
     lines = text.splitlines()
-    for line in lines:
-        line = line.strip()
-        if line:
-            line = re.sub('\s+', ' ', line)
-            clean_text.append(line)
-    return "\n".join(clean_text)
+    return [clean_line(line) for line in lines if clean_line(line)]
 
 def detect_lang(text):
     translator = Translator(service_urls=[
@@ -37,7 +45,7 @@ def split_sents(text, lang):
         raise Exception('The language {} is not suppored yet.'.format(LANG.ISO[lang]))
 
 def _split_zh(text, limit=1000):
-    sent_list = []
+    sent_list: list[str] = []
     text = re.sub('(?P<quotation_mark>([。?!](?![”’"\'])))', r'\g<quotation_mark>\n', text)
     text = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\'])', r'\g<quotation_mark>\n', text)
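
Behavioral note: `clean_text` now returns a list of cleaned lines for both input types rather than a newline-joined string (the `str` half of its return annotation is effectively unreachable), which is why the aligner now cleans after splitting. Expected behavior, per the code above:

```python
clean_text("  foo   bar \n\n baz ")        # -> ["foo bar", "baz"]
clean_text(["  foo   bar ", "", " baz "])  # -> ["foo bar", "baz"]
```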

test_chunk.py

@@ -1,7 +1,18 @@
-from typing_extensions import override
 import unittest
 import os
-from bertalign.chunk import create_aligned_chunks, TextChunk, AlignmentError
+from typing_extensions import override
+
+from bertalign.aligner import Bertalign
+from bertalign.chunk import TextChunk, split_into_paragraphs
+
+
+def print_chunks(chunks: list[TextChunk]) -> None:
+    for i, chunk in enumerate(chunks, 1):
+        print(f"\nChunk {i}:")
+        print(f"Source text ({len(chunk.source_text)} chars):")
+        print(chunk.source_text)
+        print(f"\nTarget text ({len(chunk.target_text)} chars):")
+        print(chunk.target_text)
+        print("-" * 80)
 
 class TestChunk(unittest.TestCase):
@@ -28,22 +39,27 @@ class TestChunk(unittest.TestCase):
             self.target_text = f.read()
 
     def test_create_aligned_chunks(self):
-        chunks = create_aligned_chunks(
-            self.source_text, self.target_text, max_chars=500
-        )
+        src_ps = split_into_paragraphs(self.source_text)
+        tgt_ps = split_into_paragraphs(self.target_text)
+        aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en")
+        chunks = aligner.chunk(300)
         self.assertIsInstance(chunks, list)
+        self.assertTrue(all(isinstance(chunk, TextChunk) for chunk in chunks))
         for chunk in chunks:
             self.assertIsInstance(chunk.source_text, str)
             self.assertIsInstance(chunk.target_text, str)
+        print_chunks(chunks)
         self.assertGreater(len(chunks), 0)
 
     @unittest.skip("no")
     def test_create_aligned_chunks_empty_input(self):
-        c = create_aligned_chunks("", self.target_text)
+        src_ps = split_into_paragraphs(self.source_text)
+        tgt_ps = split_into_paragraphs(self.target_text)
+        aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en")
+        c = aligner.chunk()
        self.assertIsNone(c)
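
Assuming the test module is named test_chunk.py, the rewritten test runs under the standard unittest runner:

```
python -m unittest test_chunk.TestChunk.test_create_aligned_chunks -v
```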