import sqlite3
import json
import random
from typing import List, Dict, Any
from pathlib import Path


def create_alpaca_dataset(
    db_path: str, output_path: str, samples_per_book: int = 155
) -> None:
    """
    Create an Alpaca-style JSON dataset for Chinese to English translation.

    Args:
        db_path: Path to the SQLite database
        output_path: Path where the JSON dataset will be saved
        samples_per_book: Maximum number of samples to take from each book_id
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(
        "select distinct book_id from paragraph_chunks where text_en is not null and text_zh is not null"
    )
    book_ids = [row[0] for row in cursor.fetchall()]

    dataset: List[Dict[str, Any]] = []

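    # Sample at most samples_per_book pairs per book so that longer books
    # don't dominate the dataset.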
    for book_id in book_ids:
        # Fetch all aligned, non-empty zh/en pairs for the current book_id.
        cursor.execute(
            """
            select text_zh, text_en
            from paragraph_chunks
            where book_id = ?
              and text_en is not null
              and text_zh is not null
              and length(text_zh) > 0
              and length(text_en) > 0
            """,
            (book_id,),
        )

        samples = cursor.fetchall()
        if not samples:
            continue
        selected_samples = random.sample(samples, min(len(samples), samples_per_book))
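        # Note: random.sample (and random.shuffle below) are unseeded, so the
        # selection differs between runs; call random.seed(...) first if you
        # need reproducible datasets.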
        # Alpaca format: one instruction/input/output record per pair.
        for zh_text, en_text in selected_samples:
            entry = {
                "instruction": "Translate the following Chinese text to English:",
                "input": zh_text.strip(),
                "output": en_text.strip(),
            }
            dataset.append(entry)

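    # Each record comes out shaped like this (values are illustrative):
    #   {"instruction": "Translate the following Chinese text to English:",
    #    "input": "你好，世界。",
    #    "output": "Hello, world."}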
    conn.close()
    random.shuffle(dataset)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

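    # ensure_ascii=False keeps the Chinese text human-readable in the JSON
    # file instead of escaping it to \uXXXX sequences.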
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)

    print(f"Dataset created successfully with {len(dataset)} total samples")
    print(f"Number of unique books: {len(book_ids)}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate Alpaca-style translation dataset"
    )
    parser.add_argument(
        "--db_path", type=str, required=True, help="Path to SQLite database"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Path for output JSON file"
    )
    parser.add_argument(
        "--samples_per_book",
        type=int,
        default=155,
        help="Maximum number of samples to take from each book_id",
    )

    args = parser.parse_args()
    create_alpaca_dataset(args.db_path, args.output_path, args.samples_per_book)
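# Example invocation (script, file, and path names are illustrative):
#   python create_alpaca_dataset.py --db_path books.db \
#       --output_path data/zh_en_alpaca.json --samples_per_book 155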