chore: switch DataLoader to DataConfig (template, max_samples, train/test split)
This commit is contained in:
@@ -1,13 +1,15 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from config import DataConfig
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(self, tokenizer: PreTrainedTokenizer, data_config: DataConfig):
    """Initialize the loader.

    Args:
        tokenizer: Tokenizer whose ``eos_token`` is appended to each
            formatted training example.
        data_config: Data settings (prompt template, optional sample cap,
            train/test split ratio) read by the loading/processing methods.
    """
    self.tokenizer = tokenizer
    self.data_config = data_config
|
||||
|
||||
def load_dataset(self, path: str) -> Dataset:
|
||||
"""Load dataset from local path or Google Drive"""
|
||||
@@ -28,7 +30,17 @@ class DataLoader:
|
||||
|
||||
try:
|
||||
dataset = load_dataset("json", data_files=dataset_path, split="train")
|
||||
return self.process_dataset(dataset)
|
||||
|
||||
if max_size := self.data_config.max_samples is not None:
|
||||
dataset = dataset.select(range(min(len(dataset), max_size)))
|
||||
|
||||
processed_dataset = self.process_dataset(dataset)
|
||||
# train/test split
|
||||
split_dataset = processed_dataset.train_test_split(
|
||||
test_size=(1 - self.data_config.train_split), shuffle=False
|
||||
)
|
||||
|
||||
return split_dataset
|
||||
except Exception as e:
|
||||
raise Exception(f"Error loading dataset: {e}")
|
||||
|
||||
@@ -41,7 +53,7 @@ class DataLoader:
|
||||
# Render each (input, output) pair through the configured template and
# terminate it with the tokenizer's EOS token so generation learns to stop.
# Loop names avoid shadowing the `input` builtin; the template's
# placeholder names (`input`/`output`) are unchanged.
texts: list[str] = [
    self.data_config.template.format(input=prompt, output=completion)
    + self.tokenizer.eos_token
    for prompt, completion in zip(inputs, outputs)
]
|
||||
|
||||
Reference in New Issue
Block a user