chore: pass DataConfig into DataLoader; cap dataset at max_samples and add train/test split

2025-02-14 21:50:55 +06:00
parent 45f45f4bdb
commit efd62d7c94
4 changed files with 185 additions and 65 deletions


@@ -1,13 +1,15 @@
 import os
 from typing import Any
+from config import DataConfig
 from datasets import Dataset, load_dataset
 from transformers import PreTrainedTokenizer


 class DataLoader:
-    def __init__(self, tokenizer: PreTrainedTokenizer, template: str):
+    def __init__(self, tokenizer: PreTrainedTokenizer, data_config: DataConfig):
         self.tokenizer = tokenizer
-        self._template = template
+        self.data_config = data_config
+        # self._template = template

     def load_dataset(self, path: str) -> Dataset:
         """Load dataset from local path or Google Drive"""
@@ -28,7 +30,17 @@ class DataLoader:
         try:
             dataset = load_dataset("json", data_files=dataset_path, split="train")
-            return self.process_dataset(dataset)
+
+            if (max_size := self.data_config.max_samples) is not None:
+                dataset = dataset.select(range(min(len(dataset), max_size)))
+
+            processed_dataset = self.process_dataset(dataset)
+
+            # train/test split
+            split_dataset = processed_dataset.train_test_split(
+                test_size=(1 - self.data_config.train_split), shuffle=False
+            )
+            return split_dataset
         except Exception as e:
             raise Exception(f"Error loading dataset: {e}")
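With this change load_dataset returns the DatasetDict produced by train_test_split rather than a single Dataset (the declared -> Dataset annotation is now stale). A hedged usage sketch, with the tokenizer, config instance, and path as placeholder assumptions:

loader = DataLoader(tokenizer, data_config)          # hypothetical instances
splits = loader.load_dataset("data/train.json")      # hypothetical path
train_ds, eval_ds = splits["train"], splits["test"]  # keys created by train_test_split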
@@ -41,7 +53,7 @@ class DataLoader:
         texts: list[str] = []
         for input, output in zip(inputs, outputs):
             text = (
-                self._template.format(input=input, output=output)
+                self.data_config.template.format(input=input, output=output)
                 + self.tokenizer.eos_token
             )
             texts.append(text)
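For reference, assuming a template value such as "### Input:\n{input}\n### Output:\n{output}" and an eos_token of "</s>" (both assumptions, neither is fixed by this diff), each entry appended to texts would look like:

template = "### Input:\n{input}\n### Output:\n{output}"    # assumed value of data_config.template
text = template.format(input="2+2?", output="4") + "</s>"  # "</s>" stands in for tokenizer.eos_token
# -> '### Input:\n2+2?\n### Output:\n4</s>'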