68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
import unittest
|
|
import os
|
|
from typing_extensions import override
|
|
from bertalign.aligner import Bertalign
|
|
from bertalign.chunk import TextChunk, split_into_paragraphs
|
|
|
|
|
|
def print_chunks(chunks: list[TextChunk]) -> None:
    """Dump every chunk to stdout for manual inspection.

    Chunks are numbered from 1. For each one the source text and the target
    text are printed, each preceded by a header line giving its character
    count, and the chunk is closed by a rule of 80 dashes.

    Args:
        chunks: Objects exposing ``source_text`` and ``target_text``.
    """
    rule = "-" * 80
    for number, item in enumerate(chunks, start=1):
        # One print call per chunk; sep="\n" yields the same line layout as
        # printing each piece separately.
        print(
            f"\nChunk {number}:",
            f"Source text ({len(item.source_text)} chars):",
            item.source_text,
            f"\nTarget text ({len(item.target_text)} chars):",
            item.target_text,
            rule,
            sep="\n",
        )
|
|
|
|
|
|
class TestChunk(unittest.TestCase):
    """Exercise Bertalign chunking on the ``ri_4`` zh/en sample files.

    The fixture paths are resolved relative to a ``data`` directory next to
    this module, and both files are read in ``setUp``.
    """

    data_dir: str
    source_file: str
    target_file: str
    source_text: str = ""
    target_text: str = ""

    def __init__(self, methodName: str = "runTest") -> None:
        """Initialise the test case and work out the fixture file paths."""
        super().__init__(methodName)

        module_dir = os.path.dirname(__file__)
        self.data_dir = os.path.join(module_dir, "data")
        # Source-side text file
        self.source_file = os.path.join(self.data_dir, "ri_4.zh")
        # Target-side text file
        self.target_file = os.path.join(self.data_dir, "ri_4.en")

    @override
    def setUp(self):
        """Read both fixture files as UTF-8 into the text attributes."""
        with open(self.source_file, "r", encoding="utf-8") as source_handle:
            self.source_text = source_handle.read()

        with open(self.target_file, "r", encoding="utf-8") as target_handle:
            self.target_text = target_handle.read()

    def test_create_aligned_chunks(self):
        """Chunking the sample texts gives a non-empty list of string pairs."""
        aligner = Bertalign(
            split_into_paragraphs(self.source_text),
            split_into_paragraphs(self.target_text),
            src_lang="zh",
            tgt_lang="en",
        )
        # NOTE(review): 300 is presumably a chunk size limit -- confirm
        # against Bertalign.chunk.
        produced = aligner.chunk(300)

        self.assertIsInstance(produced, list)

        for piece in produced:
            self.assertIsInstance(piece.source_text, str)
            self.assertIsInstance(piece.target_text, str)

        # Show the chunks so they can be inspected in the test output.
        print_chunks(produced)

        self.assertGreater(len(produced), 0)

    @unittest.skip("no")
    def test_create_aligned_chunks_empty_input(self):
        """Currently skipped: ``chunk()`` with no argument should give None.

        NOTE(review): despite the name, this builds the aligner from the same
        fixture text as the test above rather than from empty input.
        """
        aligner = Bertalign(
            split_into_paragraphs(self.source_text),
            split_into_paragraphs(self.target_text),
            src_lang="zh",
            tgt_lang="en",
        )
        outcome = aligner.chunk()
        self.assertIsNone(outcome)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly. The return
    # value is bound to `_` to mark it as deliberately unused.
    _ = unittest.main()
|