Files
zh-en-wn-dataset/gen_alpaca.py
2025-02-09 14:30:25 +06:00

90 lines
2.7 KiB
Python

import sqlite3
import json
import random
from typing import List, Dict, Any
from pathlib import Path
def create_alpaca_dataset(
    db_path: str, output_path: str, samples_per_book: int = 155
) -> None:
    """
    Create an Alpaca-style JSON dataset for Chinese-to-English translation.

    Reads parallel (text_zh, text_en) paragraph pairs from the
    ``paragraph_chunks`` table, takes at most ``samples_per_book`` random
    pairs per ``book_id`` (to keep the dataset balanced across books),
    shuffles the combined result, and writes it out as a JSON list of
    Alpaca-format entries (instruction / input / output).

    Args:
        db_path: Path to the SQLite database.
        output_path: Path where the JSON dataset will be saved; parent
            directories are created if missing.
        samples_per_book: Maximum number of samples to take from each book_id.
    """
    dataset: List[Dict[str, Any]] = []
    conn = sqlite3.connect(db_path)
    try:
        # try/finally guarantees the connection is closed even if a query
        # raises (the original leaked the connection on any exception).
        cursor = conn.cursor()
        cursor.execute(
            "select distinct book_id from paragraph_chunks where text_en is not null and text_zh is not null"
        )
        book_ids = [row[0] for row in cursor.fetchall()]
        for book_id in book_ids:
            # Fetch every non-null, non-empty parallel pair for this book.
            cursor.execute(
                """
                select text_zh, text_en
                from paragraph_chunks
                where book_id = ?
                  and text_en is not null
                  and text_zh is not null
                  and length(text_zh) > 0
                  and length(text_en) > 0
                """,
                (book_id,),
            )
            samples = cursor.fetchall()
            if not samples:
                continue
            # Cap each book's contribution at samples_per_book.
            selected_samples = random.sample(
                samples, min(len(samples), samples_per_book)
            )
            # Alpaca format: fixed instruction, source text as input,
            # reference translation as output.
            for zh_text, en_text in selected_samples:
                entry = {
                    "instruction": "Translate the following Chinese text to English:",
                    "input": zh_text.strip(),
                    "output": en_text.strip(),
                }
                dataset.append(entry)
    finally:
        conn.close()
    random.shuffle(dataset)
    # Separate local avoids rebinding the str parameter to a Path.
    out_path = Path(output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)
    print(f"Dataset created successfully with {len(dataset)} total samples")
    print(f"Number of unique books: {len(book_ids)}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: build the dataset from CLI arguments.
    cli = argparse.ArgumentParser(
        description="Generate Alpaca-style translation dataset"
    )
    cli.add_argument(
        "--db_path", type=str, required=True, help="Path to SQLite database"
    )
    cli.add_argument(
        "--output_path", type=str, required=True, help="Path for output JSON file"
    )
    cli.add_argument(
        "--samples_per_book",
        type=int,
        default=155,
        help="Maximum number of samples to take from each book_id",
    )
    opts = cli.parse_args()
    create_alpaca_dataset(opts.db_path, opts.output_path, opts.samples_per_book)