210 lines
6.4 KiB
Python
210 lines
6.4 KiB
Python
from bertalign.aligner import Bertalign
|
|
from dataclasses import dataclass
|
|
from typing import TypeAlias
|
|
|
|
# One alignment pairs a group of source sentence indices with a group of
# target sentence indices, as produced by Bertalign's sentence aligner.
AlignmentResult: TypeAlias = list[tuple[list[int], list[int]]]
|
|
|
|
|
|
@dataclass
class TextChunk:
    """An aligned pair of source and target paragraph groups.

    The ``*_text`` fields are the "\\n\\n"-joined forms of the paragraph
    lists; the chunking code refreshes them as paragraphs are added.
    """

    # Paragraphs belonging to this chunk, in document order.
    source_paragraphs: list[str]
    target_paragraphs: list[str]
    # Joined-text caches for the paragraph lists above.
    source_text: str = ""
    target_text: str = ""
|
|
|
|
|
|
class AlignmentError(Exception):
    """Raised when sentence alignment fails or produces inconsistent results."""

    pass
|
|
|
|
|
|
def split_into_paragraphs(text: str) -> list[str]:
    """Split *text* on blank lines, dropping empty paragraphs and edge whitespace."""
    stripped = (chunk.strip() for chunk in text.split("\n\n"))
    return [chunk for chunk in stripped if chunk]
|
|
|
|
|
|
def get_paragraph_indices(text_paragraphs: list[str]) -> tuple[dict[int, int], int]:
    """Map each global sentence index to the paragraph it belongs to.

    Sentences within a paragraph are newline-separated.

    Returns:
        A ``sentence_index -> paragraph_index`` dict and the total number
        of sentences across all paragraphs.
    """
    sentence_to_paragraph: dict[int, int] = {}
    sentence_count = 0

    for par_idx, paragraph in enumerate(text_paragraphs):
        for _ in paragraph.split("\n"):
            sentence_to_paragraph[sentence_count] = par_idx
            sentence_count += 1

    return sentence_to_paragraph, sentence_count
|
|
|
|
|
|
def validate_alignment(
    alignments: list[tuple[list[int], list[int]]], source_count: int, target_count: int
) -> None:
    """Check that alignment sentence indices fall within the known counts.

    Args:
        alignments: (source_indices, target_indices) pairs from the aligner.
        source_count: Number of sentences in the source text.
        target_count: Number of sentences in the target text.

    Raises:
        AlignmentError: If there are no alignments, one side has no indices
            at all, or any index is out of range.
    """
    if not alignments:
        raise AlignmentError("Aligner produced no results")

    source_indices = [idx for alignment in alignments for idx in alignment[0]]
    target_indices = [idx for alignment in alignments for idx in alignment[1]]

    # The aligner may emit empty index lists; max() over an empty sequence
    # would raise a bare ValueError, so report it as our own error instead.
    if not source_indices or not target_indices:
        raise AlignmentError("Aligner produced invalid sentence indices")

    # Indices must be valid positions: non-negative and below the counts.
    if (
        min(source_indices) < 0
        or min(target_indices) < 0
        or max(source_indices) >= source_count
        or max(target_indices) >= target_count
    ):
        raise AlignmentError("Aligner produced invalid sentence indices")
|
|
|
|
|
|
def validate_chunks(
    chunks: list[TextChunk], source_pars: list[str], target_pars: list[str]
) -> None:
    """Verify that the chunks exactly reproduce the original paragraph lists.

    Raises:
        AlignmentError: On any count or content mismatch on either side.
    """
    flat_source: list[str] = []
    flat_target: list[str] = []
    for chunk in chunks:
        flat_source.extend(chunk.source_paragraphs)
        flat_target.extend(chunk.target_paragraphs)

    for original, rebuilt, name in (
        (source_pars, flat_source, "Source"),
        (target_pars, flat_target, "Target"),
    ):
        if len(original) != len(rebuilt):
            raise AlignmentError(f"{name} paragraph count mismatch")
        # Equal lengths established above, so list equality matches the
        # element-wise comparison.
        if original != rebuilt:
            raise AlignmentError(f"{name} paragraph content mismatch")
|
|
|
|
|
|
def process_alignments(
    alignments: list[tuple[list[int], list[int]]],
    source_paragraphs: list[str],
    target_paragraphs: list[str],
    source_par_map: dict[int, int],
    target_par_map: dict[int, int],
    max_chars: int,
) -> list[TextChunk]:
    """Group sentence alignments into paragraph-aligned chunks.

    Walks the alignments in order, pulling each touched paragraph into the
    current chunk at most once, and closes the chunk whenever it exceeds
    ``max_chars`` at a paragraph boundary (see should_create_new_chunk).

    Args:
        alignments: (source_indices, target_indices) sentence-index pairs.
        source_paragraphs: Source paragraph texts.
        target_paragraphs: Target paragraph texts.
        source_par_map: Source sentence index -> paragraph index.
        target_par_map: Target sentence index -> paragraph index.
        max_chars: Soft per-side character limit for a chunk.

    Returns:
        Chunks that together reproduce all input paragraphs in order.

    Raises:
        AlignmentError: If the chunks fail to reproduce the input paragraphs
            (propagated from validate_chunks).
    """
    chunks: list[TextChunk] = []
    current = TextChunk([], [])
    processed_source: set[int] = set()
    processed_target: set[int] = set()

    for source_indices, target_indices in alignments:
        # Add every paragraph this alignment touches to the current chunk.
        if source_indices:
            source_range = get_paragraph_range(source_indices, source_par_map)
            add_paragraphs(
                current.source_paragraphs,
                source_paragraphs,
                source_range,
                processed_source,
            )

        if target_indices:
            target_range = get_paragraph_range(target_indices, target_par_map)
            add_paragraphs(
                current.target_paragraphs,
                target_paragraphs,
                target_range,
                processed_target,
            )

        # Keep the joined texts current so the size check sees fresh values.
        current.source_text = "\n\n".join(current.source_paragraphs)
        current.target_text = "\n\n".join(current.target_paragraphs)

        if should_create_new_chunk(
            current,
            max_chars,
            source_indices,
            target_indices,
            source_par_map,
            target_par_map,
        ):
            chunks.append(current)
            current = TextChunk([], [])
            processed_source.clear()
            processed_target.clear()

    # Bug fix: the trailing chunk may hold only target paragraphs (e.g. a
    # final target-only alignment); the previous source-only check dropped
    # such chunks, losing content and failing validation below. Keep the
    # chunk if either side is non-empty.
    if current.source_paragraphs or current.target_paragraphs:
        chunks.append(current)

    validate_chunks(chunks, source_paragraphs, target_paragraphs)
    return chunks
|
|
|
|
|
|
def get_paragraph_range(indices: list[int], par_map: dict[int, int]) -> range:
    """Return the paragraph range spanned by sorted sentence *indices* (inclusive of the last paragraph)."""
    first_par = par_map[indices[0]]
    last_par = par_map[indices[-1]]
    return range(first_par, last_par + 1)
|
|
|
|
|
|
def add_paragraphs(
    chunk_paragraphs: list[str],
    paragraphs: list[str],
    par_range: range,
    processed: set[int],
) -> None:
    """Append the paragraphs in *par_range* to *chunk_paragraphs*.

    Indices already present in *processed* are skipped; newly added indices
    are recorded there so each paragraph is taken at most once.
    """
    fresh = [idx for idx in par_range if idx not in processed]
    chunk_paragraphs.extend(paragraphs[idx] for idx in fresh)
    processed.update(fresh)
|
|
|
|
|
|
def should_create_new_chunk(
    chunk: TextChunk,
    max_chars: int,
    source_indices: list[int],
    target_indices: list[int],
    source_par_map: dict[int, int],
    target_par_map: dict[int, int],
) -> bool:
    """Decide whether the current chunk should be closed.

    A chunk is closed only once one side's text exceeds ``max_chars`` AND
    the most recent alignment ends exactly at a paragraph boundary on both
    sides, so paragraphs are never split across chunks.
    """
    over_limit = (
        len(chunk.source_text) > max_chars or len(chunk.target_text) > max_chars
    )
    if not over_limit:
        return False

    def ends_paragraph(indices: list[int], par_map: dict[int, int]) -> bool:
        # True when every index is the last sentence of its paragraph
        # (vacuously true for an empty index list).
        for idx in indices:
            paragraph = par_map[idx]
            last_in_paragraph = max(
                sent for sent, par in par_map.items() if par == paragraph
            )
            if idx != last_in_paragraph:
                return False
        return True

    return ends_paragraph(source_indices, source_par_map) and ends_paragraph(
        target_indices, target_par_map
    )
|
|
|
|
|
|
def create_aligned_chunks(
    source_text: str,
    target_text: str,
    max_chars: int = 500,
    src_lang: str = "zh",
    tgt_lang: str = "en",
) -> list[TextChunk]:
    """Create aligned chunks of text respecting paragraph and sentence alignment.

    Args:
        source_text: The source text to be chunked
        target_text: The target text to be chunked
        max_chars: Maximum characters per chunk (default: 500)
        src_lang: Source language code passed to Bertalign (default: "zh")
        tgt_lang: Target language code passed to Bertalign (default: "en")

    Returns:
        List of TextChunk objects containing aligned paragraphs

    Raises:
        AlignmentError: If the aligner returns an unexpected type, produces
            no results, or yields invalid sentence indices.
    """
    source_paragraphs = split_into_paragraphs(source_text)
    target_paragraphs = split_into_paragraphs(target_text)

    # TODO: just accept Bertalign instance as arg
    aligner = Bertalign(
        "\n".join(source_paragraphs),
        "\n".join(target_paragraphs),
        is_split=True,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    aligner.align_sents()

    if not isinstance(aligner.result, list):
        raise AlignmentError("Invalid alignment result type")

    alignments: AlignmentResult = aligner.result
    if not alignments:
        raise AlignmentError("Aligner produced no results")

    # get para mappings and validate
    source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs)
    target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs)

    validate_alignment(alignments, source_sent_count, target_sent_count)

    return process_alignments(
        alignments,
        source_paragraphs,
        target_paragraphs,
        source_par_map,
        target_par_map,
        max_chars,
    )
|