fix: faiss GPU detection; move chunking into the Bertalign API
.gitignore (vendored)
@@ -2,3 +2,5 @@ venv
 .vscode
 .pytest_cache/
 __pycache__/
+faiss
@@ -2,48 +2,86 @@ import numpy as np
 
 from bertalign import model
 from bertalign.corelib import *
+from bertalign.chunk import (
+    AlignmentResult,
+    AlignmentError,
+    TextChunk,
+    get_paragraph_indices,
+    validate_alignment,
+    process_alignments,
+)
 from bertalign.utils import *
 
 
 class Bertalign:
-    def __init__(self,
-                 src,
-                 tgt,
-                 max_align=5,
-                 top_k=3,
-                 win=5,
-                 skip=-0.1,
-                 margin=True,
-                 len_penalty=True,
-                 is_split=False,
-                 src_lang=None,
-                 tgt_lang=None
-                 ):
-
+    """An automatic multilingual sentence aligner.
+
+    Uses sentence-transformers to align sentences across languages through semantic similarity.
+    Performs two-step alignment: first finding 1-1 anchor points, then extracting more complex
+    alignments (1-many, many-1, many-many) within constrained search paths."""
+
+    def __init__(
+        self,
+        src,
+        tgt,
+        max_align=5,
+        top_k=3,
+        win=5,
+        skip=-0.1,
+        margin=True,
+        len_penalty=True,
+        is_split=False,
+        src_lang=None,
+        tgt_lang=None,
+    ):
+        """Initialize aligner with source and target texts.
+
+        Args:
+            src (str | list[str]): Source text or list of sentences
+            tgt (str | list[str]): Target text or list of sentences
+            max_align (int, optional): Max sentences per alignment (default: 5)
+            top_k (int, optional): Candidate targets per source (default: 3)
+            win (int, optional): Window size for searching in second pass (default: 5)
+            skip (float, optional): Skip penalty (default: -0.1)
+            margin (bool, optional): Use margin in scoring (default: True)
+            len_penalty (bool, optional): Apply length penalty (default: True)
+            is_split (bool, optional): Input is pre-split into lines (default: False)
+            src_lang (str, optional): Source language code (auto-detected if None)
+            tgt_lang (str, optional): Target language code (auto-detected if None)
+        """
 
         self.max_align = max_align
         self.top_k = top_k
         self.win = win
         self.skip = skip
         self.margin = margin
         self.len_penalty = len_penalty
 
-        src = clean_text(src)
-        tgt = clean_text(tgt)
 
         src_lang = src_lang if src_lang is not None else detect_lang(src)
         tgt_lang = tgt_lang if tgt_lang is not None else detect_lang(tgt)
 
-        if is_split:
+        # src = clean_text(src)
+        # tgt = clean_text(tgt)
+
+        if isinstance(src, list) and isinstance(tgt, list):
+            src_sents = src
+            tgt_sents = tgt
+        elif is_split:
             src_sents = src.splitlines()
             tgt_sents = tgt.splitlines()
         else:
             src_sents = split_sents(src, src_lang)
             tgt_sents = split_sents(tgt, tgt_lang)
 
+        src_sents = clean_text(src_sents)
+        tgt_sents = clean_text(tgt_sents)
+
         src_num = len(src_sents)
         tgt_num = len(tgt_sents)
 
         src_lang = LANG.ISO[src_lang]
         tgt_lang = LANG.ISO[tgt_lang]
 
         print("Source language: {}, Number of sentences: {}".format(src_lang, src_num))
         print("Target language: {}, Number of sentences: {}".format(tgt_lang, tgt_num))
 
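Note on the new input handling: the constructor now takes raw strings, pre-split strings (`is_split=True`), or lists of sentences. A minimal usage sketch (texts are illustrative; passing explicit language codes for list input also keeps `detect_lang` from being called on a list):

```python
from bertalign.aligner import Bertalign

# Raw strings: sentences are split and languages detected automatically.
aligner = Bertalign("这是第一句。这是第二句。", "First sentence. Second sentence.")

# Pre-split lists: the new isinstance() branch keeps them as-is.
aligner = Bertalign(
    ["这是第一句。", "这是第二句。"],
    ["First sentence.", "Second sentence."],
    src_lang="zh",  # explicit codes skip detection entirely
    tgt_lang="en",
)
```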
@@ -53,6 +91,8 @@ class Bertalign:
 
         char_ratio = np.sum(src_lens[0,]) / np.sum(tgt_lens[0,])
 
+        self.src = src
+        self.tgt = tgt
         self.src_lang = src_lang
         self.tgt_lang = tgt_lang
         self.src_sents = src_sents
@@ -64,36 +104,108 @@ class Bertalign:
         self.char_ratio = char_ratio
         self.src_vecs = src_vecs
         self.tgt_vecs = tgt_vecs
 
     def align_sents(self):
+        """Execute two-pass sentence alignment
+
+        1. First pass: finds 1-1 alignments as anchors
+        2. Second pass: finds complex n:m alignments between anchors
+        """
+
         print("Performing first-step alignment ...")
-        D, I = find_top_k_sents(self.src_vecs[0,:], self.tgt_vecs[0,:], k=self.top_k)
-        first_alignment_types = get_alignment_types(2) # 0-1, 1-0, 1-1
+        D, I = find_top_k_sents(self.src_vecs[0, :], self.tgt_vecs[0, :], k=self.top_k)
+        first_alignment_types = get_alignment_types(2)  # 0-1, 1-0, 1-1
         first_w, first_path = find_first_search_path(self.src_num, self.tgt_num)
-        first_pointers = first_pass_align(self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I)
-        first_alignment = first_back_track(self.src_num, self.tgt_num, first_pointers, first_path, first_alignment_types)
+        first_pointers = first_pass_align(
+            self.src_num, self.tgt_num, first_w, first_path, first_alignment_types, D, I
+        )
+        first_alignment = first_back_track(
+            self.src_num,
+            self.tgt_num,
+            first_pointers,
+            first_path,
+            first_alignment_types,
+        )
 
         print("Performing second-step alignment ...")
         second_alignment_types = get_alignment_types(self.max_align)
-        second_w, second_path = find_second_search_path(first_alignment, self.win, self.src_num, self.tgt_num)
-        second_pointers = second_pass_align(self.src_vecs, self.tgt_vecs, self.src_lens, self.tgt_lens,
-                                            second_w, second_path, second_alignment_types,
-                                            self.char_ratio, self.skip, margin=self.margin, len_penalty=self.len_penalty)
-        second_alignment = second_back_track(self.src_num, self.tgt_num, second_pointers, second_path, second_alignment_types)
-
-        print("Finished! Successfully aligning {} {} sentences to {} {} sentences\n".format(self.src_num, self.src_lang, self.tgt_num, self.tgt_lang))
+        second_w, second_path = find_second_search_path(
+            first_alignment, self.win, self.src_num, self.tgt_num
+        )
+        second_pointers = second_pass_align(
+            self.src_vecs,
+            self.tgt_vecs,
+            self.src_lens,
+            self.tgt_lens,
+            second_w,
+            second_path,
+            second_alignment_types,
+            self.char_ratio,
+            self.skip,
+            margin=self.margin,
+            len_penalty=self.len_penalty,
+        )
+        second_alignment = second_back_track(
+            self.src_num,
+            self.tgt_num,
+            second_pointers,
+            second_path,
+            second_alignment_types,
+        )
 
+        print(
+            "Finished! Successfully aligning {} {} sentences to {} {} sentences\n".format(
+                self.src_num, self.src_lang, self.tgt_num, self.tgt_lang
+            )
+        )
         self.result = second_alignment
 
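For orientation, each bead in `self.result` pairs a run of source sentence indices with a run of target indices; `print_sents` and `_get_line` below rely on exactly this shape. A sketch with made-up bead values (not output from a real run):

```python
src_sents = ["第一句。", "第二句。", "第三句。"]
tgt_sents = ["Sentences one and two.", "Sentence three."]

# Hypothetical two-pass output: a 2-1 bead followed by a 1-1 bead.
result = [([0, 1], [0]), ([2], [1])]

for src_idx, tgt_idx in result:
    # Mirrors _get_line: join the contiguous run covered by each side;
    # an empty side (a 0-1 or 1-0 bead) yields an empty string.
    src_line = " ".join(src_sents[src_idx[0] : src_idx[-1] + 1]) if src_idx else ""
    tgt_line = " ".join(tgt_sents[tgt_idx[0] : tgt_idx[-1] + 1]) if tgt_idx else ""
    print(src_line + "\n" + tgt_line + "\n")
```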
     def print_sents(self):
-        for bead in (self.result):
+        """Print aligned sentences in parallel"""
+        for bead in self.result:
             src_line = self._get_line(bead[0], self.src_sents)
             tgt_line = self._get_line(bead[1], self.tgt_sents)
             print(src_line + "\n" + tgt_line + "\n")
 
+    def chunk(self, max_chars: int = 512) -> list[TextChunk]:
+        """Create aligned chunks respecting paragraph and sentence boundaries
+
+        Args:
+            max_chars (int, optional): Maximum characters per chunk. Defaults to 512.
+
+        Returns:
+            list[TextChunk]: List of TextChunk with aligned paragraphs
+        """
+        source_paragraphs = self.src_sents
+        target_paragraphs = self.tgt_sents
+
+        self.align_sents()
+
+        if not isinstance(self.result, list):
+            raise AlignmentError("Invalid alignment result type")
+
+        alignments: AlignmentResult = self.result
+        if not alignments:
+            raise AlignmentError("Aligner produced no results")
+
+        # get paragraph mappings and validate
+        source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs)
+        target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs)
+
+        validate_alignment(self.result, source_sent_count, target_sent_count)
+
+        return process_alignments(
+            self.result,
+            source_paragraphs,
+            target_paragraphs,
+            source_par_map,
+            target_par_map,
+            max_chars,
+        )
+
     @staticmethod
     def _get_line(bead, lines):
-        line = ''
+        line = ""
         if len(bead) > 0:
-            line = ' '.join(lines[bead[0]:bead[-1]+1])
+            line = " ".join(lines[bead[0] : bead[-1] + 1])
         return line
 
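End to end, the new method-based chunking API reads as below; a sketch only (paths are placeholders, `split_into_paragraphs` comes from `bertalign.chunk` as in the updated tests):

```python
from bertalign.aligner import Bertalign
from bertalign.chunk import split_into_paragraphs

with open("source_zh.txt", encoding="utf-8") as f:  # placeholder path
    src_ps = split_into_paragraphs(f.read())
with open("target_en.txt", encoding="utf-8") as f:  # placeholder path
    tgt_ps = split_into_paragraphs(f.read())

aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en")
chunks = aligner.chunk(max_chars=300)  # runs align_sents() internally

for chunk in chunks:
    print(chunk.source_text)
    print(chunk.target_text)
```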
@@ -1,4 +1,3 @@
-from bertalign.aligner import Bertalign
 from dataclasses import dataclass
 from typing import TypeAlias
 
@@ -158,52 +157,3 @@ def should_create_new_chunk(
     return at_paragraph_boundary(
         source_indices, source_par_map
     ) and at_paragraph_boundary(target_indices, target_par_map)
-
-
-def create_aligned_chunks(
-    source_text: str, target_text: str, max_chars: int = 500
-) -> list[TextChunk]:
-    """Create aligned chunks of text respecting paragraph and sentence alignment.
-
-    Args:
-        source_text: The source text to be chunked
-        target_text: The target text to be chunked
-        max_chars: Maximum characters per chunk (default: 500)
-
-    Returns:
-        List of TextChunk objects containing aligned paragraphs
-    """
-    source_paragraphs = split_into_paragraphs(source_text)
-    target_paragraphs = split_into_paragraphs(target_text)
-
-    # TODO: just accept Bertalign instance as arg
-    aligner = Bertalign(
-        "\n".join(source_paragraphs),
-        "\n".join(target_paragraphs),
-        is_split=True,
-        src_lang="zh",
-        tgt_lang="en",
-    )
-    aligner.align_sents()
-
-    if not isinstance(aligner.result, list):
-        raise AlignmentError("Invalid alignment result type")
-
-    alignments: AlignmentResult = aligner.result
-    if not alignments:
-        raise AlignmentError("Aligner produced no results")
-
-    # get para mappings and validate
-    source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs)
-    target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs)
-
-    validate_alignment(aligner.result, source_sent_count, target_sent_count)
-
-    return process_alignments(
-        aligner.result,
-        source_paragraphs,
-        target_paragraphs,
-        source_par_map,
-        target_par_map,
-        max_chars,
-    )
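Call sites migrate one-to-one; the hard-coded zh-to-en pair and `is_split=True` move to the caller. A sketch of the mapping (stand-in texts):

```python
from bertalign.aligner import Bertalign
from bertalign.chunk import split_into_paragraphs

src_text = "第一段。\n\n第二段。"  # illustrative stand-ins
tgt_text = "First paragraph.\n\nSecond paragraph."

# Old (deleted): chunks = create_aligned_chunks(src_text, tgt_text, max_chars=500)
chunks = Bertalign(
    split_into_paragraphs(src_text),
    split_into_paragraphs(tgt_text),
    src_lang="zh",
    tgt_lang="en",
).chunk(max_chars=500)
```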
@@ -389,7 +389,7 @@ def find_top_k_sents(src_vecs, tgt_vecs, k=3):
         I: numpy array. Target index matrix of shape (num_src_sents, k).
     """
     embedding_size = src_vecs.shape[1]
-    if torch.cuda.is_available() and platform == 'linux': # GPU version
+    if faiss.get_num_gpus() > 0 and platform == 'linux':  # GPU version
         res = faiss.StandardGpuResources()
         index = faiss.IndexFlatIP(embedding_size)
         gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
 
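The guard change matters when torch sees CUDA but the installed faiss is a CPU-only build (or the reverse): asking faiss itself avoids constructing GPU resources that do not exist. A minimal sketch of the same top-k inner-product search (random vectors; assumes a faiss build that exposes `get_num_gpus`):

```python
import numpy as np
import faiss

d = 4  # embedding size, illustrative
tgt_vecs = np.random.rand(10, d).astype("float32")
src_vecs = np.random.rand(3, d).astype("float32")
faiss.normalize_L2(tgt_vecs)  # inner product equals cosine after L2 normalization
faiss.normalize_L2(src_vecs)

index = faiss.IndexFlatIP(d)  # exact inner-product index
if faiss.get_num_gpus() > 0:  # the same check the patched line uses
    res = faiss.StandardGpuResources()
    index = faiss.index_cpu_to_gpu(res, 0, index)
index.add(tgt_vecs)
D, I = index.search(src_vecs, 3)  # scores and indices, each of shape (3, 3)
```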
@@ -1,17 +1,25 @@
-from typing import Any
-
 import re
 from googletrans import Translator
 from sentence_splitter import SentenceSplitter
 
-def clean_text(text):
-    clean_text = []
+def clean_line(line: str) -> str:
+    """Clean a single line of text."""
+    line = line.strip()
+    if line:
+        line = re.sub(r'\s+', ' ', line)
+    return line
+
+def clean_text(text: str | list[str]) -> str | list[str]:
+    """Clean text or list of strings."""
+    if isinstance(text, list):
+        return [clean_line(line) for line in text if clean_line(line)]
+
     text = text.strip()
     lines = text.splitlines()
-    for line in lines:
-        line = line.strip()
-        if line:
-            line = re.sub('\s+', ' ', line)
-            clean_text.append(line)
-    return "\n".join(clean_text)
+    return [clean_line(line) for line in lines if clean_line(line)]
 
 def detect_lang(text):
     translator = Translator(service_urls=[
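One behavioural change worth flagging: for string input, `clean_text` used to return a single newline-joined string, while the new version returns a list of cleaned lines for both input types (which is what the aligner's new `src_sents = clean_text(src_sents)` call expects), so the `str` arm of the return annotation is now effectively unreachable. A quick sketch of the new semantics:

```python
from bertalign.utils import clean_text

print(clean_text("  hello   world \n\n  second  line "))
# ['hello world', 'second line'] -- a list now, not a joined string

print(clean_text(["  a  b ", "", " c "]))
# ['a b', 'c'] -- empty lines are dropped
```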
@@ -37,7 +45,7 @@ def split_sents(text, lang):
         raise Exception('The language {} is not supported yet.'.format(LANG.ISO[lang]))
 
 def _split_zh(text, limit=1000):
-    sent_list = []
+    sent_list: list[str] = []
     text = re.sub('(?P<quotation_mark>([。?!](?![”’"\')])))', r'\g<quotation_mark>\n', text)
     text = re.sub('(?P<quotation_mark>([。?!]|…{1,2})[”’"\')])', r'\g<quotation_mark>\n', text)
 
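The `_split_zh` substitutions break a line after sentence-final punctuation while keeping a closing quote attached to its sentence. A small sketch of the effect (group name shortened to `q`; sample text is illustrative):

```python
import re

text = '第一句。他说："第二句。"结束。'
# Break after sentence-final punctuation not followed by a closing quote...
text = re.sub('(?P<q>([。?!](?![”’"\')])))', r'\g<q>\n', text)
# ...then after a closing quote that follows sentence-final punctuation.
text = re.sub('(?P<q>([。?!]|…{1,2})[”’"\')])', r'\g<q>\n', text)
print(text.splitlines())
# ['第一句。', '他说："第二句。"', '结束。']
```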
@@ -1,7 +1,18 @@
-from typing_extensions import override
 import unittest
 import os
-from bertalign.chunk import create_aligned_chunks, TextChunk, AlignmentError
+from typing_extensions import override
+from bertalign.aligner import Bertalign
+from bertalign.chunk import TextChunk, split_into_paragraphs
 
 
+def print_chunks(chunks: list[TextChunk]) -> None:
+    for i, chunk in enumerate(chunks, 1):
+        print(f"\nChunk {i}:")
+        print(f"Source text ({len(chunk.source_text)} chars):")
+        print(chunk.source_text)
+        print(f"\nTarget text ({len(chunk.target_text)} chars):")
+        print(chunk.target_text)
+        print("-" * 80)
+
+
 class TestChunk(unittest.TestCase):
@@ -28,22 +39,27 @@ class TestChunk(unittest.TestCase):
             self.target_text = f.read()
 
     def test_create_aligned_chunks(self):
-        chunks = create_aligned_chunks(
-            self.source_text, self.target_text, max_chars=500
-        )
+        src_ps = split_into_paragraphs(self.source_text)
+        tgt_ps = split_into_paragraphs(self.target_text)
+        aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en")
+        chunks = aligner.chunk(300)
 
         self.assertIsInstance(chunks, list)
         self.assertTrue(all(isinstance(chunk, TextChunk) for chunk in chunks))
 
         for chunk in chunks:
             self.assertIsInstance(chunk.source_text, str)
             self.assertIsInstance(chunk.target_text, str)
 
+        print_chunks(chunks)
+
         self.assertGreater(len(chunks), 0)
 
+    @unittest.skip("no")
     def test_create_aligned_chunks_empty_input(self):
-        c = create_aligned_chunks("", self.target_text)
+        src_ps = split_into_paragraphs(self.source_text)
+        tgt_ps = split_into_paragraphs(self.target_text)
+        aligner = Bertalign(src_ps, tgt_ps, src_lang="zh", tgt_lang="en")
+        c = aligner.chunk()
         self.assertIsNone(c)