"""Compute token sequence-length statistics for a translation dataset.

Loads an Alpaca-style JSON dataset, tokenizes each (input, output) pair with a
Qwen2 tokenizer after applying the project's translation prompt template, and
reports max/average sequence lengths plus a histogram image.
"""

import json

from tqdm import tqdm
from tqdm.contrib.concurrent import process_map  # NOTE(review): currently unused — kept for callers elsewhere; confirm before removing

from torchtune.data import Message
from torchtune.models.qwen2 import qwen2_tokenizer

from prompts.translation import TranslateTemplate


def analyze_sequence_lengths(vocab_path, merges_path, json_path):
    """Tokenize every sample in the dataset and report length statistics.

    Args:
        vocab_path: Path to the Qwen2 tokenizer ``vocab.json``.
        merges_path: Path to the Qwen2 tokenizer ``merges.txt``.
        json_path: Path to a JSON file containing a list of samples, each a
            dict with at least ``"input"`` and ``"output"`` string keys.

    Returns:
        A tuple ``(max_len, lengths)``: the maximum token-sequence length
        (0 for an empty dataset) and the per-sample list of lengths.

    Side effects:
        Prints summary statistics and writes ``sequence_lengths.png`` to the
        current working directory.
    """
    tokenizer = qwen2_tokenizer(vocab_path, merges_path)
    translate_template = TranslateTemplate()

    with open(json_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    lengths = []
    for sample in tqdm(dataset):
        # Build a user/assistant message pair, apply the translation
        # template, then count the resulting tokens.
        msgs = [
            Message(role="user", content=sample["input"]),
            Message(role="assistant", content=sample["output"]),
        ]
        templated_msgs = translate_template(msgs)
        tokens, _mask = tokenizer.tokenize_messages(templated_msgs)  # mask unused
        lengths.append(len(tokens))

    # Guard against an empty dataset: the original divided by len(lengths)
    # unconditionally, which raised ZeroDivisionError.
    max_len = max(lengths) if lengths else 0
    avg_len = sum(lengths) / len(lengths) if lengths else 0.0

    print(f"\nDataset size: {len(dataset)} samples")
    # BUG FIX: the original printed a hard-coded "3,292" instead of the
    # computed maximum.
    print(f"Maximum sequence length: {max_len:,}")
    print(f"Average sequence length: {avg_len:.2f}")

    # Plot the length distribution. Import stays local so the statistics
    # path does not require matplotlib to be installed.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=50)
    plt.title("Distribution of Sequence Lengths")
    plt.xlabel("Sequence Length")
    plt.ylabel("Count")
    plt.savefig("sequence_lengths.png")  # or .jpg
    plt.close()

    return max_len, lengths


if __name__ == "__main__":
    # Example usage — guarded so importing this module does not trigger a
    # full dataset pass (the original ran this unconditionally at import).
    vocab_path = "/home/mira/models/Qwen2.5-7B-Base/vocab.json"
    merges_path = "/home/mira/models/Qwen2.5-7B-Base/merges.txt"
    dataset = "/home/mira/models/datasets/GuoFeng/datasets/dataset_v3.0_alpaca_noinstr.json"
    max_len, lengths = analyze_sequence_lengths(vocab_path, merges_path, dataset)