import argparse
import json
import os
import random
import sqlite3
import unicodedata
from typing import List, Optional, Tuple

# Markers delimiting a run of lines that process_text() joins into a single line.
SECTION_START = "#<#"
SECTION_END = "#>#"

# Instruction text stored in every dataset record.
INSTRUCTION = "Translate the following Chinese text to English:"


def get_chapters(cursor: sqlite3.Cursor, max_length: int) -> List[Tuple[str, str]]:
    """Return ``(text_en, text_zh)`` rows whose English text is shorter than *max_length*.

    Rows with a NULL ``text_en`` are excluded implicitly (``length(NULL) < ?`` is
    NULL, i.e. not true); rows with a NULL ``text_zh`` are excluded explicitly
    because ``process_text`` cannot handle ``None``.
    """
    query = """
        select text_en, text_zh
        from chapters
        where length(text_en) < ? and text_zh is not null
    """
    return cursor.execute(query, (max_length,)).fetchall()


def process_text(text: str) -> str:
    """Normalize chapter text and join marked multi-line sections.

    - Removes BOM characters (U+FEFF) anywhere in the text and applies NFKC.
    - Strips surrounding whitespace from the text and from every line.
    - Lines between a line starting with ``#<#`` and a line ending with ``#>#``
      (inclusive, markers removed) are joined with single spaces into one line.
      Start and end markers may be on the same line.
    - Empty lines are preserved outside marked sections and dropped inside them.
    - A section left open at the end of the text is still emitted.
    """
    text = text.replace("\ufeff", "")
    text = unicodedata.normalize("NFKC", text)

    processed_lines: List[str] = []
    current_group: List[str] = []
    in_marked_section = False

    for raw_line in text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            if not in_marked_section:
                processed_lines.append("")
            continue

        if line.startswith(SECTION_START):
            in_marked_section = True
            line = line[len(SECTION_START):].strip()

        # Checked after the start marker is removed so "#<# text #>#" closes too.
        if line.endswith(SECTION_END):
            closing = line[: -len(SECTION_END)].strip()
            if closing:  # a bare "#>#" line must not add a trailing space
                current_group.append(closing)
            processed_lines.append(" ".join(current_group))
            current_group = []
            in_marked_section = False
        elif in_marked_section:
            if line:  # a bare "#<#" line must not add a leading space
                current_group.append(line)
        else:
            processed_lines.append(line)

    if current_group:
        processed_lines.append(" ".join(current_group))

    return "\n".join(processed_lines)


def _write_dataset(chapters: List[Tuple[str, str]], filepath: str) -> None:
    """Write *chapters* to *filepath* as JSON Lines instruction/input/output records."""
    with open(filepath, "w", encoding="utf-8") as f:
        for text_en, text_zh in chapters:
            entry = {
                "instruction": INSTRUCTION,
                "input": process_text(text_zh),
                "output": process_text(text_en),
            }
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def _validation_path(output_path: str) -> str:
    """Insert ``_val`` before the file extension of *output_path*.

    ``data/train.jsonl`` -> ``data/train_val.jsonl``; ``out`` -> ``out_val``.
    """
    # splitext only looks at the final path component, so dots in directory
    # names and extension-less file names are handled correctly.
    root, ext = os.path.splitext(output_path)
    return f"{root}_val{ext}"


def create_datasets(
    db_path: str,
    output_path: str,
    val_split: float = 0.0,
    max_length: int = 36000,
    shuffle: bool = False,
    seed: Optional[int] = None,
) -> None:
    """Build a JSON Lines translation dataset from the ``chapters`` table.

    Args:
        db_path: Path to the SQLite database.
        output_path: File the training records are written to.
        val_split: Fraction of chapters (0 <= val_split < 1) written to a
            validation file named like *output_path* with ``_val`` inserted
            before the extension. No validation file is written when 0.
        max_length: Only chapters whose English text is shorter than this
            (in characters) are used.
        shuffle: Shuffle chapters before splitting.
        seed: Seed for the shuffle; the global ``random`` state is not touched.

    Raises:
        ValueError: If *val_split* is outside [0, 1).
    """
    if not 0 <= val_split < 1:  # also rejects NaN
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    conn = sqlite3.connect(db_path)
    try:
        chapters = get_chapters(conn.cursor(), max_length)
    finally:
        conn.close()  # rows are fully fetched; no need to hold the DB open

    if shuffle:
        random.Random(seed).shuffle(chapters)

    # The validation set is taken from the front of the (possibly shuffled) list.
    val_size = int(len(chapters) * val_split)
    train_chapters = chapters[val_size:]
    val_chapters = chapters[:val_size]

    _write_dataset(train_chapters, output_path)
    if val_split > 0:
        _write_dataset(val_chapters, _validation_path(output_path))

    print(f"Created dataset with {len(train_chapters)} training examples")
    if val_split > 0:
        print(f"Created validation set with {len(val_chapters)} examples")


def main() -> None:
    """Parse command-line arguments and build the dataset."""
    parser = argparse.ArgumentParser(description="Create parallel text dataset")
    parser.add_argument("--db-path", required=True, help="Path to SQLite database")
    parser.add_argument("--output", required=True, help="Output path for the dataset")
    parser.add_argument(
        "--max-length",
        type=int,
        default=36000,
        help="Maximum length for English text (default: 36000)",
    )
    parser.add_argument("--shuffle", action="store_true", help="Shuffle the chapters")
    parser.add_argument("--seed", type=int, help="Random seed for shuffling")
    parser.add_argument(
        "--val-split",
        type=float,
        default=0.0,
        help="Fraction of data (0 <= x < 1) to use for validation (default: 0.0)",
    )
    args = parser.parse_args()

    if not 0 <= args.val_split < 1:  # also rejects NaN
        parser.error("Validation split must be between 0 and 1")

    create_datasets(
        db_path=args.db_path,
        output_path=args.output,
        val_split=args.val_split,
        max_length=args.max_length,
        shuffle=args.shuffle,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()