"""Convert InterText sentence alignments into Alpaca-style JSON.

Scans an input directory for alignment files (names containing both ``_zh.``
and ``_en.xml``). For each one it reads the two sentence-split documents
referenced by the ``fromDoc``/``toDoc`` attributes of the alignment root.
Every aligned pair is written as ``{"input": <zh text>, "output": <en text>}``
into a single JSON list.
"""

import os
import json
import argparse
from typing import Any
from xml.etree.ElementTree import parse


def mkdir(file_path: str) -> None:
    """Create the parent directory of *file_path* if it does not exist yet.

    A bare file name (no directory component) is a no-op.
    """
    directory = os.path.dirname(file_path)
    if directory:
        # exist_ok=True replaces the racy "exists? then makedirs" check.
        os.makedirs(directory, exist_ok=True)


def main() -> None:
    """Parse CLI arguments, convert every alignment file, and write the JSON."""
    parser = argparse.ArgumentParser(description="Convert Intertext to Alpaca JSON")
    _ = parser.add_argument(
        "-i",
        "--input",
        type=str,
        required=True,
        help="Input directory for Intertext alignments.",
    )
    _ = parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        required=True,
        help="Output file for JSON files.",
    )
    args = parser.parse_args()

    mkdir(args.output_file)

    all_pairs: list[dict[str, str]] = []
    # Alignment files are recognised purely by file name. The sentence
    # documents they reference are resolved relative to the same directory.
    align_files: list[str] = [
        f for f in os.listdir(args.input) if "_zh." in f and "_en.xml" in f
    ]
    for align_file in sorted(align_files):
        print(f"Processing {align_file}...")
        align_path = os.path.join(args.input, align_file)
        root = parse(align_path).getroot()
        zh_file = root.get("fromDoc")
        en_file = root.get("toDoc")
        if not zh_file or not en_file:
            print(f"Warning: Missing fromDoc or toDoc in {align_file}")
            continue

        zh_sents = get_sents(os.path.join(args.input, zh_file))
        en_sents = get_sents(os.path.join(args.input, en_file))
        alignments = get_alignments(align_path)
        all_pairs.extend(create_pairs(zh_sents, en_sents, alignments))

    write_json(all_pairs, args.output_file)
    print(f"Created {args.output_file} with {len(all_pairs)} pairs")


def create_pairs(
    zh_sents: list[str],
    en_sents: list[str],
    alignments: list[tuple[list[int], list[int]]],
) -> list[dict[str, str]]:
    """Turn alignment beads into ``{"input": zh, "output": en}`` dicts.

    Each alignment is a ``(zh_indices, en_indices)`` tuple of 0-based sentence
    indices. A bead is dropped when either side resolves to an empty string,
    e.g. 1-0 / 0-1 beads or beads whose sentences are all empty.
    """
    pairs = []
    for zh_idx, en_idx in alignments:
        zh_sent = find_sent_by_id(zh_idx, zh_sents)
        en_sent = find_sent_by_id(en_idx, en_sents)
        if zh_sent and en_sent:  # both sentences should exist
            pairs.append({"input": zh_sent, "output": en_sent})
    return pairs


def write_json(pairs: list[dict[str, Any]], out_file: str) -> None:
    """Write *pairs* to *out_file* as pretty-printed UTF-8 JSON.

    ``ensure_ascii=False`` keeps CJK characters readable instead of
    ``\\uXXXX`` escapes.
    """
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(pairs, f, ensure_ascii=False, indent=2)


def find_sent_by_id(idx: list[int], sents: list[str]) -> str:
    """Return the text spanned by the sentence indices in *idx*.

    Joins every sentence from ``idx[0]`` through ``idx[-1]`` inclusive (the
    whole contiguous range, not only the listed indices) with single spaces,
    then strips the result. An empty *idx* yields ``""``. Out-of-range
    indices are silently truncated by slicing.
    """
    # NOTE(review): a space is also inserted between Chinese sentences, which
    # may be undesirable for zh text — confirm with the consumer of the JSON.
    if not idx:
        return ""
    return " ".join(sents[idx[0] : idx[-1] + 1]).strip()


def get_alignments(file: str) -> list[tuple[list[int], list[int]]]:
    """Read ``<link xtargets=...>`` elements from an alignment file.

    Returns a list of ``(zh_indices, en_indices)`` tuples of 0-based sentence
    indices. Links with a missing or empty ``xtargets`` attribute are skipped.
    Only ``link`` elements that are direct children of the root are read.
    """
    doc = parse(file)
    links = []
    for link in doc.iterfind("link"):
        if xtargets := link.get("xtargets"):
            # NOTE(review): the part before ';' is treated as English and the
            # part after as Chinese, although main() reads fromDoc as the
            # Chinese document. Confirm this matches the order InterText
            # writes xtargets in.
            en_link, zh_link = xtargets.split(";")
            links.append((parse_link(zh_link), parse_link(en_link)))
    return links


def parse_link(link: str) -> list[int]:
    """Convert one side of an ``xtargets`` value into 0-based indices.

    ``"1:1 1:2"`` -> ``[0, 1]``. Only the number after the first ``:`` is
    used. An empty or whitespace-only side yields ``[]``.
    """
    # NOTE(review): get_sents() flattens sentences across all <p> elements, so
    # this assumes the number after ':' is numbered document-wide, not per
    # paragraph — verify against the InterText id scheme in use.
    # str.split() (no argument) drops empty fields, so repeated or surrounding
    # whitespace does not produce "" items that would crash int().
    return [int(item.split(":")[1]) - 1 for item in link.split()]


def get_sents(file: str) -> list[str]:
    """Return the text of every ``p/s`` element in *file*, in document order.

    Empty ``<s/>`` elements yield ``""`` rather than ``None``.
    """
    doc = parse(file)
    # Element.text is None for an empty element. Mapping it to "" keeps the
    # list truly list[str], so " ".join() in find_sent_by_id cannot raise
    # TypeError. NOTE(review): .text only covers text before the first child
    # element. Confirm <s> elements never contain nested markup.
    return [sent.text or "" for sent in doc.iterfind("p/s")]


if __name__ == "__main__":
    main()