# bertalign/bertalign/chunk.py
from dataclasses import dataclass
from typing import TypeAlias

from bertalign.aligner import Bertalign

AlignmentResult: TypeAlias = list[tuple[list[int], list[int]]]
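# Each alignment pairs source sentence indices with target sentence indices,
# e.g. [([0], [0]), ([1, 2], [1])] for a 1-1 followed by a 2-1 alignment
# (illustrative values, not from a real run).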


@dataclass
class TextChunk:
    """A pair of aligned source/target paragraph groups and their joined texts."""

    source_paragraphs: list[str]
    target_paragraphs: list[str]
    source_text: str = ""
    target_text: str = ""


class AlignmentError(Exception):
    """Raised when the aligner's output is empty, malformed, or inconsistent."""


def split_into_paragraphs(text: str) -> list[str]:
    """Split text on blank lines, dropping empty paragraphs."""
    return [p.strip() for p in text.split("\n\n") if p.strip()]
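
# Illustrative behaviour (hypothetical input):
#   split_into_paragraphs("First.\n\n  Second. \n\n")  ->  ["First.", "Second."]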


def get_paragraph_indices(text_paragraphs: list[str]) -> tuple[dict[int, int], int]:
    """Map each global sentence index to its paragraph index.

    Sentences are assumed to be newline-separated within each paragraph.
    Returns the mapping and the total sentence count.
    """
    paragraph_map: dict[int, int] = {}
    total_sentences = 0
    for i, paragraph in enumerate(text_paragraphs):
        num_sentences = len(paragraph.split("\n"))
        paragraph_map.update({total_sentences + j: i for j in range(num_sentences)})
        total_sentences += num_sentences
    return paragraph_map, total_sentences
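
# Illustrative mapping (hypothetical input): two paragraphs holding two and one
# sentences respectively:
#   get_paragraph_indices(["s0\ns1", "s2"])  ->  ({0: 0, 1: 0, 2: 1}, 3)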


def validate_alignment(
    alignments: AlignmentResult, source_count: int, target_count: int
) -> None:
    """Reject empty alignments or sentence indices outside the known ranges."""
    if not alignments:
        raise AlignmentError("Aligner produced no results")
    # default=-1 guards against the case where every group on one side is empty,
    # which would otherwise make max() raise ValueError.
    max_source = max(
        (idx for alignment in alignments for idx in alignment[0]), default=-1
    )
    max_target = max(
        (idx for alignment in alignments for idx in alignment[1]), default=-1
    )
    if max_source >= source_count or max_target >= target_count:
        raise AlignmentError("Aligner produced invalid sentence indices")


def validate_chunks(
    chunks: list[TextChunk], source_pars: list[str], target_pars: list[str]
) -> None:
    """Check that chunking preserved every paragraph, in order and unmodified."""
    chunk_source_pars = [par for chunk in chunks for par in chunk.source_paragraphs]
    chunk_target_pars = [par for chunk in chunks for par in chunk.target_paragraphs]

    def validate_pars(orig: list[str], chunked: list[str], name: str) -> None:
        if len(orig) != len(chunked):
            raise AlignmentError(f"{name} paragraph count mismatch")
        if not all(o == c for o, c in zip(orig, chunked)):
            raise AlignmentError(f"{name} paragraph content mismatch")

    validate_pars(source_pars, chunk_source_pars, "Source")
    validate_pars(target_pars, chunk_target_pars, "Target")


def process_alignments(
    alignments: AlignmentResult,
    source_paragraphs: list[str],
    target_paragraphs: list[str],
    source_par_map: dict[int, int],
    target_par_map: dict[int, int],
    max_chars: int,
) -> list[TextChunk]:
    """Group aligned sentence pairs into TextChunks of roughly max_chars each."""
    chunks: list[TextChunk] = []
    current = TextChunk([], [])
    processed_source: set[int] = set()
    processed_target: set[int] = set()
    for source_indices, target_indices in alignments:
        # Pull every paragraph touched by this alignment into the current
        # chunk, skipping paragraphs that were already added.
        if source_indices:
            source_range = get_paragraph_range(source_indices, source_par_map)
            add_paragraphs(
                current.source_paragraphs,
                source_paragraphs,
                source_range,
                processed_source,
            )
        if target_indices:
            target_range = get_paragraph_range(target_indices, target_par_map)
            add_paragraphs(
                current.target_paragraphs,
                target_paragraphs,
                target_range,
                processed_target,
            )
        # Refresh the joined texts so the size check below sees current content.
        current.source_text = "\n\n".join(current.source_paragraphs)
        current.target_text = "\n\n".join(current.target_paragraphs)
        if should_create_new_chunk(
            current,
            max_chars,
            source_indices,
            target_indices,
            source_par_map,
            target_par_map,
        ):
            chunks.append(current)
            current = TextChunk([], [])
            processed_source.clear()
            processed_target.clear()
    # Flush the final chunk; check both sides so a trailing target-only
    # chunk is not silently dropped.
    if current.source_paragraphs or current.target_paragraphs:
        chunks.append(current)
    validate_chunks(chunks, source_paragraphs, target_paragraphs)
    return chunks


def get_paragraph_range(indices: list[int], par_map: dict[int, int]) -> range:
    """Return the range of paragraph indices spanned by a run of sentences."""
    return range(par_map[indices[0]], par_map[indices[-1]] + 1)
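
# Illustrative call (hypothetical map): with par_map = {0: 0, 1: 0, 2: 1},
#   get_paragraph_range([1, 2], par_map)  ->  range(0, 2)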


def add_paragraphs(
    chunk_paragraphs: list[str],
    paragraphs: list[str],
    par_range: range,
    processed: set[int],
) -> None:
    """Append paragraphs from par_range to the chunk, skipping already-added ones."""
    for idx in par_range:
        if idx not in processed:
            chunk_paragraphs.append(paragraphs[idx])
            processed.add(idx)
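
# Illustrative behaviour (hypothetical values): overlapping ranges append each
# paragraph only once.
#   out, seen = [], set()
#   add_paragraphs(out, ["p0", "p1", "p2"], range(0, 2), seen)  # out == ["p0", "p1"]
#   add_paragraphs(out, ["p0", "p1", "p2"], range(1, 3), seen)  # out == ["p0", "p1", "p2"]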


def should_create_new_chunk(
    chunk: TextChunk,
    max_chars: int,
    source_indices: list[int],
    target_indices: list[int],
    source_par_map: dict[int, int],
    target_par_map: dict[int, int],
) -> bool:
    """Decide whether to close the current chunk.

    A chunk is closed only once it exceeds max_chars on either side *and*
    the current alignment ends exactly at a paragraph boundary, so a chunk
    never splits a paragraph.
    """
    if len(chunk.source_text) <= max_chars and len(chunk.target_text) <= max_chars:
        return False

    def at_paragraph_boundary(indices: list[int], par_map: dict[int, int]) -> bool:
        if not indices:
            return True
        # True only when every aligned sentence is the final sentence of its
        # paragraph, i.e. the largest sentence index mapped to that paragraph.
        return all(
            idx == max(i for i, par_idx in par_map.items() if par_idx == par_map[idx])
            for idx in indices
        )

    return at_paragraph_boundary(
        source_indices, source_par_map
    ) and at_paragraph_boundary(target_indices, target_par_map)
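
# Illustrative boundary check (hypothetical map): with par_map = {0: 0, 1: 0, 2: 1},
# indices [1] and [2] are at a boundary (each is its paragraph's last sentence),
# while [0, 1] is not, because sentence 0 does not end paragraph 0.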


def create_aligned_chunks(
    source_text: str, target_text: str, max_chars: int = 500
) -> list[TextChunk]:
    """Create aligned chunks of text respecting paragraph and sentence alignment.

    Args:
        source_text: The source text to be chunked.
        target_text: The target text to be chunked.
        max_chars: Maximum characters per chunk (default: 500).

    Returns:
        List of TextChunk objects containing aligned paragraphs.
    """
    source_paragraphs = split_into_paragraphs(source_text)
    target_paragraphs = split_into_paragraphs(target_text)
    # TODO: accept a Bertalign instance as an argument instead of
    # constructing one here.
    aligner = Bertalign(
        "\n".join(source_paragraphs),
        "\n".join(target_paragraphs),
        is_split=True,
        src_lang="zh",
        tgt_lang="en",
    )
    aligner.align_sents()
    if not isinstance(aligner.result, list):
        raise AlignmentError("Invalid alignment result type")
    alignments: AlignmentResult = aligner.result
    if not alignments:
        raise AlignmentError("Aligner produced no results")
    # Build sentence-to-paragraph maps and check the aligner's indices
    # against the actual sentence counts before chunking.
    source_par_map, source_sent_count = get_paragraph_indices(source_paragraphs)
    target_par_map, target_sent_count = get_paragraph_indices(target_paragraphs)
    validate_alignment(alignments, source_sent_count, target_sent_count)
    return process_alignments(
        alignments,
        source_paragraphs,
        target_paragraphs,
        source_par_map,
        target_par_map,
        max_chars,
    )
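

# Minimal usage sketch (illustrative, not part of the module's API): assumes a
# working Bertalign installation with its sentence-embedding model available.
# The input strings below are made up; results will vary by model.
if __name__ == "__main__":
    src = "你好。\n今天天气很好。"
    tgt = "Hello.\nThe weather is nice today."
    try:
        for chunk in create_aligned_chunks(src, tgt, max_chars=200):
            print(repr(chunk.source_text), "<->", repr(chunk.target_text))
    except AlignmentError as exc:
        print(f"alignment failed: {exc}")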