# Files
# zh-en-wn-dataset/paragraph_ctx_collect.py
# 2025-02-09 14:30:25 +06:00
#
# 250 lines
# 7.6 KiB
# Python
from typing import List, Tuple
import sqlite3
import re
def get_chapter_paragraphs(
    cursor: sqlite3.Cursor, book_id: str, chapter_id: str
) -> Tuple[List[str], List[str]]:
    """
    Fetch all paragraphs of one chapter as two parallel lists.

    Rows where either side is NULL/empty are dropped entirely, so the two
    returned lists stay index-aligned.

    Args:
        cursor: Open cursor on the paragraphs database.
        book_id: Book identifier.
        chapter_id: Chapter identifier within the book.

    Returns:
        (english_paragraphs, chinese_paragraphs), both whitespace-stripped.
    """
    cursor.execute(
        """
        select text_en, text_zh
        from paragraphs
        where book_id = ? and chapter_id = ?
        """,
        (book_id, chapter_id),
    )
    pairs = [
        (en.strip(), zh.strip())
        for en, zh in cursor.fetchall()
        if en and zh  # skip rows missing either language
    ]
    if not pairs:
        return [], []
    en_list = [en for en, _ in pairs]
    zh_list = [zh for _, zh in pairs]
    return en_list, zh_list
def get_text_state(text: str) -> tuple[int, bool, bool]:
    """
    Analyzes text for continuity markers used to decide chunk boundaries.

    Args:
        text: String to analyze.

    Returns:
        tuple containing:
            - int: Net change in CJK bracket depth (positive for unclosed,
              negative for extra closing brackets)
            - bool: Whether the text ends with a colon (ASCII ":" or
              full-width ":"), i.e. it introduces what follows
            - bool: Whether the text ends without proper sentence
              termination (".", "!", "?" or full-width equivalents)
    """
    if not text:
        return 0, False, False
    # Count CJK bracket balance; a positive net means an unclosed span.
    opens = len(re.findall(r"[【「『]", text))
    closes = len(re.findall(r"[】」』]", text))
    stripped = text.rstrip()
    ends_with_punct = bool(re.search(r"[.!?。!?]\s*$", stripped))
    # Fix: also recognize the full-width colon ":", consistent with the
    # full-width sentence punctuation already handled above.
    ends_with_colon = stripped.endswith((":", ":"))
    return (opens - closes, ends_with_colon, not ends_with_punct)
def create_chunks(
    en_texts: List[str],
    zh_texts: List[str],
    target_size: int = 1024,
    min_size: int = 512,
    max_size: int = 2048,
) -> List[Tuple[str, str]]:
    """
    Creates parallel text chunks respecting continuity markers and size constraints.

    Sizes are measured on the English text only; the Chinese side simply
    follows the same paragraph boundaries (zh_texts is indexed in lockstep
    with en_texts, so it must be at least as long). A boundary is only placed
    where the English paragraph is not inside CJK brackets, does not end with
    a colon, and ends with sentence-final punctuation (see get_text_state).

    Args:
        en_texts: List of English text paragraphs
        zh_texts: List of corresponding Chinese text paragraphs (parallel to en_texts)
        target_size: Ideal size for each chunk in characters
        min_size: Minimum acceptable chunk size
        max_size: Maximum acceptable chunk size

    Returns:
        List of tuples containing (english_chunk, chinese_chunk); paragraphs
        within a chunk are joined with blank lines. A trailing remainder
        shorter than min_size is silently dropped.
    """
    chunks = []
    current_en = []  # English paragraphs accumulated for the pending chunk
    current_zh = []  # Chinese paragraphs, kept index-aligned with current_en
    current_chars = 0  # English character count of the pending chunk
    bracket_depth = 0  # net CJK bracket depth carried across paragraphs
    i = 0
    while i < len(en_texts):
        current_text = en_texts[i]
        para_chars = len(current_text)
        bracket_change, ends_with_colon, incomplete_sentence = get_text_state(
            current_text
        )
        bracket_depth += bracket_change
        # check if adding this paragraph would push the chunk past max_size
        if current_chars + para_chars > max_size:
            # only split if we're not in brackets, sentence is complete, and have met min_size
            if (
                bracket_depth <= 0
                and not incomplete_sentence
                and current_chars >= min_size
            ):
                chunks.append(("\n\n".join(current_en), "\n\n".join(current_zh)))
                current_en = []
                current_zh = []
                current_chars = 0
                # NOTE(review): unlike the target_size flush below,
                # bracket_depth is NOT reset here — confirm intentional.
        # add the current paragraph pair to the pending chunk
        current_en.append(current_text)
        current_zh.append(zh_texts[i])
        current_chars += para_chars
        # can we close out a chunk at this paragraph boundary?
        next_exists = i + 1 < len(en_texts)
        if (
            current_chars >= target_size
            and bracket_depth <= 0
            and not ends_with_colon
            and not incomplete_sentence
            # presumably defers the final chunk to the tail flush below;
            # verify that is the intent of requiring a following paragraph
            and next_exists
        ):
            chunks.append(("\n\n".join(current_en), "\n\n".join(current_zh)))
            current_en = []
            current_zh = []
            current_chars = 0
            bracket_depth = 0
        i += 1
    # flush the remaining text only if it reaches min_size
    if current_chars >= min_size:
        chunks.append(("\n\n".join(current_en), "\n\n".join(current_zh)))
    return chunks
def create_chunk_table(cursor: sqlite3.Cursor):
    """Ensure the paragraph_chunks table exists (no-op when already present)."""
    ddl = """
        create table if not exists paragraph_chunks (
            id integer primary key autoincrement,
            book_id text not null,
            chapter_id text not null,
            chunk_index integer not null,
            text_en text,
            text_zh text,
            char_count integer,
            foreign key (book_id, chapter_id) references chapters(book_id, chapter_id),
            unique(book_id, chapter_id, chunk_index)
        )
    """
    cursor.execute(ddl)
def store_book_chunks(db_path: str, book_id: str):
    """
    Process a book and store its chunks in the database.

    Opens its own connection, ensures the paragraph_chunks table exists,
    re-chunks the book via process_book, and upserts each chunk keyed by
    (book_id, chapter_id, chunk_index) so reruns overwrite prior results.

    Args:
        db_path: Path to the SQLite database file.
        book_id: Identifier of the book to process.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        create_chunk_table(cursor)
        chunks_by_chapter = process_book(db_path, book_id)
        for chapter_id, chapter_chunks in chunks_by_chapter:
            for i, (en_chunk, zh_chunk) in enumerate(chapter_chunks):
                cursor.execute(
                    """
                    insert into paragraph_chunks
                    (book_id, chapter_id, chunk_index, text_en, text_zh, char_count)
                    values (?, ?, ?, ?, ?, ?)
                    on conflict(book_id, chapter_id, chunk_index)
                    do update set
                    text_en = excluded.text_en,
                    text_zh = excluded.text_zh,
                    char_count = excluded.char_count
                    """,
                    # char_count reflects the English side only
                    (book_id, chapter_id, i, en_chunk, zh_chunk, len(en_chunk)),
                )
        conn.commit()
    finally:
        # Fix: previously the connection leaked if process_book or an
        # insert raised; always close it.
        conn.close()
def process_book(db_path: str, book_id: str) -> List[Tuple[str, List[Tuple[str, str]]]]:
    """
    Process book chapter by chapter, respecting chapter boundaries.

    Args:
        db_path: Path to the SQLite database file.
        book_id: Identifier of the book to process.

    Returns:
        List of (chapter_id, chapter_chunks) tuples, where chapter_chunks is
        the (english_chunk, chinese_chunk) list produced by create_chunks.
        Chapters with no usable paragraph pairs are omitted.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # NOTE(review): chapter_id is text, so this ordering is
        # lexicographic ("ch10" < "ch2") — confirm ids sort correctly.
        cursor.execute(
            """
            select distinct chapter_id
            from paragraphs
            where book_id = ?
            order by chapter_id
            """,
            (book_id,),
        )
        chapter_ids = [row[0] for row in cursor.fetchall()]
        all_chapter_chunks = []
        for chapter_id in chapter_ids:
            en_texts, zh_texts = get_chapter_paragraphs(cursor, book_id, chapter_id)
            if en_texts and zh_texts:  # skip empty chapters
                chapter_chunks = create_chunks(en_texts, zh_texts)
                all_chapter_chunks.append((chapter_id, chapter_chunks))
    finally:
        # Fix: previously the connection leaked if any query raised;
        # always close it.
        conn.close()
    return all_chapter_chunks
def process_all_books(db_path: str):
    """Run chunking and storage for every book listed in the books table."""
    # Collect the ids first and close the connection before the per-book
    # work, since store_book_chunks opens its own connection.
    connection = sqlite3.connect(db_path)
    rows = connection.cursor().execute("select book_id from books").fetchall()
    connection.close()
    for (book_id,) in rows:
        print(f"Processing and storing book: {book_id}")
        store_book_chunks(db_path, book_id)
if __name__ == "__main__":
    import sys

    if len(sys.argv) == 3 and sys.argv[1] == "--store":
        # Store mode: chunk every book in the given database.
        db_path = sys.argv[2]
        process_all_books(db_path)
    else:
        # Smoke test: paragraphs containing a bracketed note and colons,
        # which the chunker must not split mid-continuation.
        test_en = [
            "On it were words left by Wen Jin's parents:",
            "【We learned from the news that you two got married.",
            "Take care of each other in the future, if you need anything,",
            "talk to us, even though you may not need to.",
            "From Mom and Dad.】",
            "After reading this, Wen Jin felt:",
            "A complex mix of emotions surged through him.",
            'Returning home with the parcels, Jiang Wan asked him, "Should the shoes be unpacked?"',
        ]
        test_zh = [f"zh{n}" for n in range(len(test_en))]
        for i, (en, zh) in enumerate(create_chunks(test_en, test_zh, target_size=1024), 1):
            print(f"\nChunk {i}:")
            print(en)
            print("-" * 40)