chore: more!

2025-02-11 13:28:12 +06:00
parent 28342e0ace
commit befdc9c945
2068 changed files with 102392 additions and 908 deletions


@@ -2,30 +2,34 @@ import sqlite3
 import json
 from typing import List, Tuple
 import unicodedata
+import argparse
+import random

-def get_chapters(cursor) -> List[Tuple[str, str]]:
-    """text length < 36000"""
+def get_chapters(cursor, max_length: int) -> List[Tuple[str, str]]:
+    """Get chapters with English text length less than max_length"""
     query = """
         select text_en, text_zh
         from chapters
-        where length(text_en) < 36000
+        where length(text_en) < ?
     """
-    return cursor.execute(query).fetchall()
+    return cursor.execute(query, (max_length,)).fetchall()

-def should_join_lines(line: str) -> bool:
-    """Check if line should be joined with next line based on ending"""
-    line = line.rstrip()
-    return line.endswith(",") or (
-        line.count('"') % 2 == 1
-    )  # odd number of quotes means open quote
+# def should_join_lines(line: str) -> bool:
+#     """Check if line should be joined with next line based on ending"""
+#     line = line.rstrip()
+#     return line.endswith(",") or (
+#         line.count('"') % 2 == 1
+#     )  # odd number of quotes means open quote

 def process_text(text: str) -> str:
     """Process text by handling special markings and line breaks"""
     # remove BOM chars
     text = text.replace("\ufeff", "")
     text = unicodedata.normalize("NFKC", text)
     # strip leading/trailing \n
     lines = text.strip().split("\n")
     processed_lines = []
     current_group = []
@@ -34,20 +38,20 @@ def process_text(text: str) -> str:
     for line in lines:
         line = line.strip()
         if not line:
             # preserve empty lines
             if not in_marked_section:
                 processed_lines.append("")
             continue
         if line.startswith("#<#"):
             # Start of marked section - remove marker and store first line
             in_marked_section = True
-            first_line = line[3:].strip()  # Remove #<# and whitespace
+            first_line = line[3:].strip()
             current_group.append(first_line)
             continue
         if line.endswith("#>#"):
             # End of marked section - remove marker and store last line
-            last_line = line[:-3].strip()  # Remove #># and whitespace
+            last_line = line[:-3].strip()
             current_group.append(last_line)
             # Join all collected lines with space and add to processed lines
             processed_lines.append(" ".join(current_group))
             current_group = []
             in_marked_section = False
@@ -58,38 +62,99 @@ def process_text(text: str) -> str:
         else:
             processed_lines.append(line)

     # Handle any remaining grouped lines (in case of malformed input)
     if current_group:
         processed_lines.append(" ".join(current_group))

-    # Join with double newlines
-    return "\n\n".join(processed_lines)
+    return "\n".join(processed_lines)

-def create_dataset(db_path: str, output_path: str):
+def create_datasets(
+    db_path: str,
+    output_path: str,
+    val_split: float = 0.0,
+    max_length: int = 36000,
+    shuffle: bool = False,
+    seed: int = None,
+):
+    if seed is not None:
+        random.seed(seed)
+
     conn = sqlite3.connect(db_path)
     cursor = conn.cursor()
     try:
-        chapters = get_chapters(cursor)
-        with open(output_path, "w", encoding="utf-8") as f:
-            for text_en, text_zh in chapters:
-                processed_en = process_text(text_en)
-                processed_zh = process_text(text_zh)
-                entry = {
-                    "text": f"<|im_start|>user\n{processed_zh}<|im_end|>\n<|im_start|>assistant\n{processed_en}<|im_end|>"
-                }
-                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+        chapters = get_chapters(cursor, max_length)
+        if shuffle:
+            random.shuffle(chapters)
+
+        # split into train and validation sets
+        val_size = int(len(chapters) * val_split)
+        train_chapters = chapters[val_size:]
+        val_chapters = chapters[:val_size]
+
+        # Helper function to write datasets
+        def write_dataset(chapters: List[Tuple[str, str]], filepath: str):
+            with open(filepath, "w", encoding="utf-8") as f:
+                for text_en, text_zh in chapters:
+                    processed_en = process_text(text_en)
+                    processed_zh = process_text(text_zh)
+                    entry = {
+                        "text": f"<|im_start|>user\n{processed_zh}<|im_end|>\n<|im_start|>assistant\n{processed_en}<|im_end|>"
+                    }
+                    f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+
+        # Write train dataset
+        write_dataset(train_chapters, output_path)
+
+        # Write validation dataset if val_split > 0
+        if val_split > 0:
+            val_path = output_path.rsplit(".", 1)
+            val_path = f"{val_path[0]}_val.{val_path[1]}"
+            write_dataset(val_chapters, val_path)
+
+        print(f"Created dataset with {len(train_chapters)} training examples")
+        if val_split > 0:
+            print(f"Created validation set with {len(val_chapters)} examples")
     finally:
         conn.close()
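Also as an aside: given the f-string above, each line of the resulting JSONL decodes to a single "text" field holding the Chinese chapter as the user turn and the English chapter as the assistant turn in ChatML-style markers. A hedged sketch of reading one entry back (the path is the old hard-coded default from the removed __main__ block below, used purely as a placeholder):

# Illustration only; assumes a dataset already written by write_dataset().
import json

with open("datasets/dataset_v1.jsonl", encoding="utf-8") as f:
    entry = json.loads(next(f))
assert entry["text"].startswith("<|im_start|>user\n")
assert "<|im_start|>assistant\n" in entry["text"]
assert entry["text"].endswith("<|im_end|>")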

-if __name__ == "__main__":
-    DB_PATH = "parallel_texts.db"
-    OUTPUT_PATH = "datasets/dataset_v1.jsonl"
-    create_dataset(DB_PATH, OUTPUT_PATH)
+def main():
+    parser = argparse.ArgumentParser(description="Create parallel text dataset")
+    parser.add_argument("--db-path", required=True, help="Path to SQLite database")
+    parser.add_argument("--output", required=True, help="Output path for the dataset")
+    parser.add_argument(
+        "--max-length",
+        type=int,
+        default=36000,
+        help="Maximum length for English text (default: 36000)",
+    )
+    parser.add_argument("--shuffle", action="store_true", help="Shuffle the chapters")
+    parser.add_argument("--seed", type=int, help="Random seed for shuffling")
+    parser.add_argument(
+        "--val-split",
+        type=float,
+        default=0.0,
+        help="Percentage of data to use for validation (default: 0.0)",
+    )
+
+    args = parser.parse_args()
+    if args.val_split < 0 or args.val_split >= 1:
+        parser.error("Validation split must be between 0 and 1")
+
+    create_datasets(
+        db_path=args.db_path,
+        output_path=args.output,
+        val_split=args.val_split,
+        max_length=args.max_length,
+        shuffle=args.shuffle,
+        seed=args.seed,
+    )
+
+
+if __name__ == "__main__":
+    main()
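Finally, not part of the commit, a minimal usage sketch of the new entry point. The filename make_dataset.py is an assumption (the diff does not show it), and the paths simply reuse the old hard-coded defaults:

# Hypothetical invocation from the shell:
#   python make_dataset.py --db-path parallel_texts.db \
#       --output datasets/dataset_v1.jsonl --val-split 0.1 --shuffle --seed 42
#
# Or calling the function directly, bypassing argparse:
from make_dataset import create_datasets  # assumed module name

create_datasets(
    db_path="parallel_texts.db",
    output_path="datasets/dataset_v1.jsonl",
    val_split=0.1,  # also writes datasets/dataset_v1_val.jsonl
    max_length=36000,
    shuffle=True,
    seed=42,
)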