#!/usr/bin/env python3 import json import argparse import random from collections.abc import Sequence from pathlib import Path from typing import TypedDict class AlpacaEntry(TypedDict): input: str output: str instruction: str | None class Args(argparse.Namespace): def __init__(self) -> None: super().__init__() self.file1: str = "" self.file2: str = "" self.output: str = "" self.shuffle: bool = False self.seed: int | None = None self.omit_instruction: bool = False def load_file(file_path: str) -> Sequence[AlpacaEntry]: """Load and validate a file in JSON or JSONL format.""" path = Path(file_path) is_jsonl = path.suffix.lower() in (".jsonl", ".ljson") with open(file_path, "r", encoding="utf-8") as f: if is_jsonl: data: list[AlpacaEntry] = [] for line in f: line = line.strip() if line: # Skip empty lines entry: AlpacaEntry = json.loads(line) data.append(entry) else: data = json.load(f) # Validate required fields for item in data: if not all(key in item for key in ["input", "output"]): raise ValueError( f"Missing required fields in {file_path}. Each item must have 'input' and 'output' fields." ) return data def write_output(data: Sequence[AlpacaEntry], output_path: str) -> None: """Write output in JSON or JSONL format based on file extension.""" path = Path(output_path) is_jsonl = path.suffix.lower() in (".jsonl", ".ljson") print(path, is_jsonl) with open(output_path, "w", encoding="utf-8") as f: if is_jsonl: for item in data: _ = f.write(json.dumps(item, ensure_ascii=False) + "\n") else: json.dump(data, f, ensure_ascii=False, indent=2) def merge_datasets( file1: str, file2: str, shuffle: bool = False, omit_instruction: bool = False ) -> Sequence[AlpacaEntry]: """Merge two files in JSON/JSONL format.""" # Load both files data1 = load_file(file1) data2 = load_file(file2) merged_data = list(data1) + list(data2) if omit_instruction: for item in merged_data: _ = item.pop("instruction", None) if shuffle: random.shuffle(merged_data) return merged_data def parse_args() -> Args: """Parse and validate command line arguments.""" parser = argparse.ArgumentParser(description="Merge two Alpaca-format JSON files") _ = parser.add_argument("file1", type=str, help="Path to first JSON file") _ = parser.add_argument("file2", type=str, help="Path to second JSON file") _ = parser.add_argument("output", type=str, help="Path to output merged JSON file") _ = parser.add_argument( "--shuffle", action="store_true", help="Shuffle the merged dataset" ) _ = parser.add_argument("--seed", type=int, help="Random seed for shuffling") _ = parser.add_argument( "--omit-instruction", action="store_true", help="Omit instruction field from output", ) return parser.parse_args(namespace=Args()) def main() -> None: args = parse_args() try: if args.seed is not None: random.seed(args.seed) merged_data = merge_datasets( args.file1, args.file2, args.shuffle, args.omit_instruction ) write_output(merged_data, args.output) print(f"Successfully merged files. Total entries: {len(merged_data)}") except Exception as e: print(f"Error: {str(e)}") exit(1) if __name__ == "__main__": main()