161 lines
4.7 KiB
Python
161 lines
4.7 KiB
Python
import argparse
import json
import random
import sqlite3
import unicodedata
from typing import List, Optional, Tuple
|
def get_chapters(cursor, max_length: int) -> List[Tuple[str, str]]:
    """Fetch (text_en, text_zh) pairs whose English text is shorter than *max_length*.

    Args:
        cursor: An open sqlite3 cursor on a database with a `chapters` table.
        max_length: Exclusive upper bound on `length(text_en)`.

    Returns:
        All matching rows as a list of (text_en, text_zh) tuples.
    """
    sql = """
    select text_en, text_zh
    from chapters
    where length(text_en) < ?
    """
    cursor.execute(sql, (max_length,))
    return cursor.fetchall()
|
|
|
|
|
|
# def should_join_lines(line: str) -> bool:
|
|
# """Check if line should be joined with next line based on ending"""
|
|
# line = line.rstrip()
|
|
# return line.endswith(",") or (
|
|
# line.count('"') % 2 == 1
|
|
# ) # odd number of quotes means open quote
|
|
|
|
|
|
def process_text(text: str) -> str:
    """Normalize chapter text and collapse marked multi-line sections.

    Processing steps:
      * Remove BOM characters and apply NFKC Unicode normalization.
      * Strip leading/trailing whitespace from the text and each line.
      * Join every run of lines between a "#<#" opener and a "#>#" closer
        into one space-separated line; the markers are removed.
      * Preserve empty lines that fall outside marked sections.

    Args:
        text: Raw chapter text.

    Returns:
        The cleaned text with marked sections collapsed.
    """
    # Remove BOM chars and normalize compatibility forms (e.g. ligatures).
    text = text.replace("\ufeff", "")
    text = unicodedata.normalize("NFKC", text)

    lines = text.strip().split("\n")
    processed_lines = []
    current_group = []
    in_marked_section = False

    for line in lines:
        line = line.strip()
        if not line:
            # Preserve empty lines, but not inside a marked section.
            if not in_marked_section:
                processed_lines.append("")
            continue

        if line.startswith("#<#"):
            in_marked_section = True
            first_line = line[3:].strip()
            # Bug fix: a section opened and closed on the same line
            # ("#<# text #>#") previously kept the literal "#>#" in the
            # output and left the section open; close it immediately.
            if first_line.endswith("#>#"):
                current_group.append(first_line[:-3].strip())
                processed_lines.append(" ".join(current_group))
                current_group = []
                in_marked_section = False
            else:
                current_group.append(first_line)
            continue

        if line.endswith("#>#"):
            last_line = line[:-3].strip()
            current_group.append(last_line)
            processed_lines.append(" ".join(current_group))
            current_group = []
            in_marked_section = False
            continue

        if in_marked_section:
            current_group.append(line)
        else:
            processed_lines.append(line)

    # Flush a section that was opened but never closed.
    if current_group:
        processed_lines.append(" ".join(current_group))

    return "\n".join(processed_lines)
|
|
|
|
|
|
def create_datasets(
    db_path: str,
    output_path: str,
    val_split: float = 0.0,
    max_length: int = 36000,
    shuffle: bool = False,
    seed: Optional[int] = None,
):
    """Build JSONL chat-format train (and optional validation) datasets.

    Reads (text_en, text_zh) chapter pairs from the SQLite database, cleans
    both sides with process_text, and writes one JSON object per line in a
    chat template (user = Chinese source, assistant = English translation).

    Args:
        db_path: Path to the SQLite database containing a `chapters` table.
        output_path: Path of the training JSONL file to write.
        val_split: Fraction in [0, 1) of chapters held out for validation;
            the validation file gets a "_val" suffix next to output_path.
        max_length: Skip chapters whose English text is this long or longer.
        shuffle: Shuffle chapters before splitting/writing.
        seed: Optional random seed for reproducible shuffling.
    """
    if seed is not None:
        random.seed(seed)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    try:
        chapters = get_chapters(cursor, max_length)

        if shuffle:
            random.shuffle(chapters)

        # Split into train and validation sets (validation from the front).
        val_size = int(len(chapters) * val_split)
        train_chapters = chapters[val_size:]
        val_chapters = chapters[:val_size]

        def write_dataset(rows: List[Tuple[str, str]], filepath: str):
            """Write one chat-formatted JSON object per (en, zh) pair."""
            with open(filepath, "w", encoding="utf-8") as f:
                for text_en, text_zh in rows:
                    processed_en = process_text(text_en)
                    processed_zh = process_text(text_zh)

                    entry = {
                        "text": f"<|im_start|>user\n{processed_zh}<|im_end|>\n<|im_start|>assistant\n{processed_en}<|im_end|>"
                    }
                    f.write(json.dumps(entry, ensure_ascii=False) + "\n")

        # Write train dataset
        write_dataset(train_chapters, output_path)

        # Write validation dataset if val_split > 0
        if val_split > 0:
            # Bug fix: rsplit(".", 1) raised IndexError for extension-less
            # output paths; fall back to appending a plain "_val" suffix.
            root, sep, ext = output_path.rpartition(".")
            val_path = f"{root}_val.{ext}" if sep else f"{output_path}_val"
            write_dataset(val_chapters, val_path)

        print(f"Created dataset with {len(train_chapters)} training examples")
        if val_split > 0:
            print(f"Created validation set with {len(val_chapters)} examples")

    finally:
        conn.close()
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments and build the dataset files."""
    cli = argparse.ArgumentParser(description="Create parallel text dataset")
    cli.add_argument("--db-path", required=True, help="Path to SQLite database")
    cli.add_argument("--output", required=True, help="Output path for the dataset")
    cli.add_argument(
        "--max-length",
        type=int,
        default=36000,
        help="Maximum length for English text (default: 36000)",
    )
    cli.add_argument("--shuffle", action="store_true", help="Shuffle the chapters")
    cli.add_argument("--seed", type=int, help="Random seed for shuffling")
    cli.add_argument(
        "--val-split",
        type=float,
        default=0.0,
        help="Percentage of data to use for validation (default: 0.0)",
    )

    opts = cli.parse_args()

    # Only fractions in the half-open interval [0, 1) make sense here.
    if opts.val_split < 0 or opts.val_split >= 1:
        cli.error("Validation split must be between 0 and 1")

    create_datasets(
        db_path=opts.db_path,
        output_path=opts.output,
        val_split=opts.val_split,
        max_length=opts.max_length,
        shuffle=opts.shuffle,
        seed=opts.seed,
    )
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|