chore: ss

This commit is contained in:
2025-02-14 23:53:46 +06:00
parent efd62d7c94
commit 39e69c90b1
4 changed files with 67 additions and 31 deletions

View File

@@ -31,10 +31,19 @@ class DataLoader:
try:
dataset = load_dataset("json", data_files=dataset_path, split="train")
if max_size := self.data_config.max_samples is not None:
print(self.data_config)
print(f"Dataset size before processing: {len(dataset)}")
if (max_size := self.data_config.max_samples) is not None:
dataset = dataset.select(range(min(len(dataset), max_size)))
# dataset.save_to_disk("/workspace/truncated_dataset")
processed_dataset = self.process_dataset(dataset)
print(f"Dataset size after processing: {len(processed_dataset)}")
# processed_dataset.save_to_disk("/workspace/processed_dataset")
# train/test split
split_dataset = processed_dataset.train_test_split(
test_size=(1 - self.data_config.train_split), shuffle=False
@@ -47,15 +56,15 @@ class DataLoader:
def process_dataset(self, dataset: Dataset) -> Dataset:
"""Process and format the dataset"""
def formatting_func(examples: dict[str, Any]) -> dict[str, list[str]]:
inputs: list[str] = examples["input"]
outputs: list[str] = examples["output"]
texts: list[str] = []
for input, output in zip(inputs, outputs):
text = (
self.data_config.template.format(input=input, output=output)
+ self.tokenizer.eos_token
)
EOS_TOKEN = self.tokenizer.eos_token
template = self.data_config.template
def formatting_func(examples):
inputs = examples["input"]
outputs = examples["output"]
texts = []
for inp, out in zip(inputs, outputs):
text = template.format(inp, out) + EOS_TOKEN
texts.append(text)
return {"text": texts}