Files
zh-en-wn-dataset/sequence_len.py
2025-02-13 18:37:06 +06:00

60 lines
1.8 KiB
Python

from torchtune.data import Message
from torchtune.models.qwen2 import qwen2_tokenizer
from prompts.translation import TranslateTemplate
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
import json
def analyze_sequence_lengths(vocab_path, merges_path, json_path):
    """Tokenize every sample of an Alpaca-style JSON dataset and report
    sequence-length statistics.

    Each sample is rendered as a user/assistant message pair, passed
    through the translation prompt template, and tokenized with the
    Qwen2 tokenizer. Prints dataset size, maximum and average sequence
    length, and saves a histogram to ``sequence_lengths.png``.

    Args:
        vocab_path: Path to the Qwen2 tokenizer ``vocab.json``.
        merges_path: Path to the Qwen2 tokenizer ``merges.txt``.
        json_path: Path to a JSON file containing a list of samples,
            each with ``"input"`` and ``"output"`` string keys.

    Returns:
        Tuple ``(max_len, lengths)``: the maximum sequence length and
        the list of per-sample token counts.
    """
    # Load Qwen2 tokenizer
    tokenizer = qwen2_tokenizer(vocab_path, merges_path)
    translate_template = TranslateTemplate()
    with open(json_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
    lengths = []
    for sample in tqdm(dataset):
        # Convert sample to messages and apply the translation template
        # so the measured length matches what training will actually see.
        msgs = [
            Message(role="user", content=sample["input"]),
            Message(role="assistant", content=sample["output"]),
        ]
        templated_msgs = translate_template(msgs)
        # Tokenize messages (the returned mask is not needed here).
        tokens, _mask = tokenizer.tokenize_messages(templated_msgs)
        lengths.append(len(tokens))
    # Guard against an empty dataset so we never divide by zero.
    max_len = max(lengths, default=0)
    avg_len = sum(lengths) / len(lengths) if lengths else 0.0
    print(f"\nDataset size: {len(dataset)} samples")
    # BUG FIX: this previously printed the hard-coded literal "3,476"
    # instead of the computed maximum.
    print(f"Maximum sequence length: {max_len:,}")
    print(f"Average sequence length: {avg_len:.2f}")
    # Optional: plot distribution (imported lazily so the statistics
    # above still work in environments without matplotlib).
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=50)
    plt.title("Distribution of Sequence Lengths")
    plt.xlabel("Sequence Length")
    plt.ylabel("Count")
    plt.savefig("sequence_lengths.png")  # or .jpg
    plt.close()
    return max_len, lengths
# Example usage: sequence-length analysis for the GuoFeng v3.0 dataset.
# Wrapped in a __main__ guard so importing this module does not trigger
# the (slow, I/O-heavy) analysis as a side effect; running the script
# directly behaves exactly as before.
if __name__ == "__main__":
    vocab_path = "/home/mira/models/Qwen2.5-7B-Base/vocab.json"
    merges_path = "/home/mira/models/Qwen2.5-7B-Base/merges.txt"
    dataset = "/home/mira/models/datasets/GuoFeng/datasets/dataset_v3.0_alpaca_noinstr.json"
    max_len, lengths = analyze_sequence_lengths(vocab_path, merges_path, dataset)