chore: add Alpaca JSON dataset generator
gen_alpaca.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import sqlite3
import json
import random
from typing import List, Dict, Any
from pathlib import Path


def create_alpaca_dataset(
    db_path: str, output_path: str, samples_per_book: int = 155
) -> None:
    """
    Create an Alpaca-style JSON dataset for Chinese-to-English translation.

    Args:
        db_path: Path to the SQLite database
        output_path: Path where the JSON dataset will be saved
        samples_per_book: Maximum number of samples to take from each book_id
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(
        "select distinct book_id from paragraph_chunks where text_en is not null and text_zh is not null"
    )
    book_ids = [row[0] for row in cursor.fetchall()]

    dataset: List[Dict[str, Any]] = []

    for book_id in book_ids:
        # Fetch all non-empty zh/en pairs for the current book_id
        cursor.execute(
            """
            select text_zh, text_en
            from paragraph_chunks
            where book_id = ?
              and text_en is not null
              and text_zh is not null
              and length(text_zh) > 0
              and length(text_en) > 0
            """,
            (book_id,),
        )

        samples = cursor.fetchall()
        if not samples:
            continue
        # Cap each book's contribution so long books don't dominate the dataset
        selected_samples = random.sample(samples, min(len(samples), samples_per_book))
        # Alpaca format: instruction / input / output
        for zh_text, en_text in selected_samples:
            entry = {
                "instruction": "Translate the following Chinese text to English:",
                "input": zh_text.strip(),
                "output": en_text.strip(),
            }
            dataset.append(entry)

    conn.close()
    # Shuffle so entries from the same book are not grouped together
    random.shuffle(dataset)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)

    print(f"Dataset created successfully with {len(dataset)} total samples")
    print(f"Number of unique books: {len(book_ids)}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate Alpaca-style translation dataset"
    )
    parser.add_argument(
        "--db_path", type=str, required=True, help="Path to SQLite database"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Path for output JSON file"
    )
    parser.add_argument(
        "--samples_per_book",
        type=int,
        default=155,
        help="Maximum number of samples to take from each book_id",
    )

    args = parser.parse_args()
    create_alpaca_dataset(args.db_path, args.output_path, args.samples_per_book)
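
For context, create_alpaca_dataset only assumes a paragraph_chunks table with book_id, text_zh, and text_en columns (the names its queries reference); any further schema details are unknown here. A minimal smoke test under that assumption, using hypothetical file names demo.sqlite and out/alpaca.json:

# Sketch: build a throwaway database matching the assumed shape of
# paragraph_chunks, then run the generator against it.
import sqlite3

from gen_alpaca import create_alpaca_dataset

conn = sqlite3.connect("demo.sqlite")
conn.execute(
    "create table if not exists paragraph_chunks "
    "(book_id integer, text_zh text, text_en text)"
)
conn.executemany(
    "insert into paragraph_chunks (book_id, text_zh, text_en) values (?, ?, ?)",
    [
        (1, "你好，世界。", "Hello, world."),
        (1, "今天天气很好。", "The weather is nice today."),
    ],
)
conn.commit()
conn.close()

create_alpaca_dataset("demo.sqlite", "out/alpaca.json", samples_per_book=2)

Note that random.sample and random.shuffle are unseeded, so selection and ordering vary across runs; calling random.seed(...) beforehand makes the dataset reproducible.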
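
Each entry follows the standard Alpaca instruction/input/output layout. A sketch of reading the file back and rendering one record with the commonly used Alpaca prompt template (the template wording is an assumption; this commit does not define one):

import json

with open("out/alpaca.json", encoding="utf-8") as f:
    dataset = json.load(f)

# Widely used Alpaca prompt template (assumed; not defined in this repo).
PROMPT = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n"
)

example = dataset[0]
print(PROMPT.format(instruction=example["instruction"], input=example["input"]))
print(example["output"])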