Files
zh-en-wn-dataset/parallel_text_import.py
2025-02-09 03:07:07 +06:00

120 lines
3.3 KiB
Python

import sqlite3
import re
from typing import List, Tuple, Dict
from dataclasses import dataclass
@dataclass
class TextUnit:
    """A single chapter of text together with the ids that locate it."""

    # Identifier of the book the chapter belongs to.
    book_id: str
    # Identifier of the chapter.
    chapter_id: str
    # Text content of the chapter.
    text: str
def parse_file(filename: str) -> List[TextUnit]:
    """Parse a tagged text file into one TextUnit per chapter.

    The file is read line by line as UTF-8. Three kinds of tag are
    recognised, each only when it starts at the beginning of a line:

    * ``<BOOK id="...">`` sets the book id used for the chapters that follow.
    * ``<CHAPTER id="...">`` opens a chapter and starts an empty text buffer.
    * ``</CHAPTER>`` or ``</BOOK>`` closes the open chapter, if there is one.

    Any other line is chapter content when a chapter is open and is ignored
    otherwise. Content lines are joined as-is, so line endings are kept.

    Args:
        filename: Path of the file to parse.

    Returns:
        The parsed chapters in file order. A chapter with no content lines is
        omitted, and a chapter that is never closed by an end tag is not
        emitted.
    """
    book_pattern = re.compile(r'<BOOK id="([^"]+)">')
    chapter_pattern = re.compile(r'<CHAPTER id="([^"]+)">')
    end_pattern = re.compile(r"</(?:BOOK|CHAPTER)>")

    units: List[TextUnit] = []
    current_book = ""
    current_chapter = ""
    current_text: List[str] = []
    # True between a <CHAPTER ...> tag and the next end tag. Without this
    # flag a </BOOK> that follows </CHAPTER> would emit the chapter twice.
    in_chapter = False

    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            # BOOK opening tag
            book_match = book_pattern.match(line)
            if book_match:
                current_book = book_match.group(1)
                continue

            # CHAPTER opening tag: start collecting a new chapter
            chapter_match = chapter_pattern.match(line)
            if chapter_match:
                current_chapter = chapter_match.group(1)
                current_text = []
                in_chapter = True
                continue

            # Any end tag closes the open chapter (if any) exactly once.
            if end_pattern.match(line):
                if in_chapter and current_text:
                    units.append(
                        TextUnit(
                            book_id=current_book,
                            chapter_id=current_chapter,
                            text="".join(current_text),
                        )
                    )
                in_chapter = False
                current_text = []
                continue

            # Everything else is content, but only while a chapter is open;
            # stray lines between chapters must not leak into any chapter.
            if in_chapter:
                current_text.append(line)

    return units
def create_database(db_name: str = "parallel_texts.db") -> sqlite3.Connection:
    """Open the SQLite database and apply the schema script to it.

    The schema is read from ``schema.sql`` in the current working directory
    before the database is opened, so a missing schema file does not leave an
    empty database file behind.

    Args:
        db_name: Database path passed to ``sqlite3.connect``.

    Returns:
        The open connection, with the schema script executed and committed.
        The caller is responsible for closing it.

    Raises:
        OSError: If ``schema.sql`` cannot be opened or read.
        sqlite3.Error: If the schema script fails; the connection is closed
            before the error propagates.
    """
    with open("schema.sql", "r", encoding="utf-8") as f:
        schema_sql = f.read()

    conn = sqlite3.connect(db_name)
    try:
        conn.executescript(schema_sql)
        conn.commit()
    except Exception:
        # Do not leak the connection when the schema cannot be applied.
        conn.close()
        raise
    return conn
def import_texts(
    en_units: List[TextUnit], zh_units: List[TextUnit], conn: sqlite3.Connection
):
    """Write the parsed English and Chinese chapters into the database.

    One ``books`` row is inserted (or ignored if already present) per distinct
    book id found in ``en_units``, in first-seen order. One ``chapters`` row is
    inserted (or replaced) per English unit; its Chinese text is the unit in
    ``zh_units`` with the same ``(book_id, chapter_id)``, or an empty string
    when there is none. Chinese units without an English counterpart are not
    written. If ``zh_units`` contains the same key more than once, the last
    one wins.

    All writes happen in a single transaction: it is committed on success and
    rolled back if any statement fails, after which the error propagates.

    Args:
        en_units: English chapters; these drive which rows are written.
        zh_units: Chinese chapters, matched to the English ones by ids.
        conn: Open connection to a database that has the ``books`` and
            ``chapters`` tables.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # insertion order of books is the same on every run (a set would not be).
    book_rows = [
        (book_id,) for book_id in dict.fromkeys(unit.book_id for unit in en_units)
    ]

    # Look-up table for the Chinese side, keyed by (book, chapter).
    zh_dict = {(unit.book_id, unit.chapter_id): unit.text for unit in zh_units}

    chapter_rows = [
        (
            unit.book_id,
            unit.chapter_id,
            unit.text,
            zh_dict.get((unit.book_id, unit.chapter_id), ""),
        )
        for unit in en_units
    ]

    # The connection context manager commits on success and rolls back on
    # error, so a failed import never leaves a half-written state pending.
    with conn:
        conn.executemany(
            "insert or ignore into books (book_id) values (?)", book_rows
        )
        conn.executemany(
            """
            insert or replace into chapters (book_id, chapter_id, text_en, text_zh) values (?, ?, ?, ?)
            """,
            chapter_rows,
        )
def main():
    """Parse both training files, load them into the database, print counts.

    Reads ``train.en`` and ``train.zh`` from the current working directory,
    creates the database with its default name, imports the parsed chapters
    and prints how many books and chapters the database then holds. The
    connection is closed even when the import or the count queries fail.
    """
    en_units = parse_file("train.en")
    zh_units = parse_file("train.zh")

    # create and populate database
    conn = create_database()
    try:
        import_texts(en_units, zh_units, conn)

        # stats
        book_count = conn.execute("select count(*) from books").fetchone()[0]
        chapter_count = conn.execute("select count(*) from chapters").fetchone()[0]
        print(f"Imported {book_count} books and {chapter_count} chapters.")
    finally:
        conn.close()
# Run the import only when executed as a script, not when imported.
if __name__ == "__main__":
    main()