import os

from config import DataConfig
from datasets import Dataset, DatasetDict, load_dataset
from transformers import PreTrainedTokenizer


class DataLoader:
    def __init__(self, tokenizer: PreTrainedTokenizer, data_config: DataConfig):
        self.tokenizer = tokenizer
        self.data_config = data_config
        # self._template = template

    def load_dataset(self, path: str) -> DatasetDict:
        """Load dataset from a local path or Google Drive."""
        if "drive.google.com" in str(path):
            try:
                import gdown

                local_path = "downloaded_dataset.json"
                if not os.path.exists(local_path):
                    gdown.download(url=path, output=local_path, fuzzy=True)
                dataset_path = local_path
            except ImportError:
                raise ImportError("Please install gdown: pip install gdown")
            except Exception as e:
                raise RuntimeError(f"Error downloading from Google Drive: {e}") from e
        else:
            dataset_path = path

        try:
            dataset = load_dataset("json", data_files=dataset_path, split="train")

            print(self.data_config)
            print(f"Dataset size before processing: {len(dataset)}")

            # Optionally cap the number of samples.
            if (max_size := self.data_config.max_samples) is not None:
                dataset = dataset.select(range(min(len(dataset), max_size)))

            # dataset.save_to_disk("/workspace/truncated_dataset")

            processed_dataset = self.process_dataset(dataset)
            print(f"Dataset size after processing: {len(processed_dataset)}")

            # processed_dataset.save_to_disk("/workspace/processed_dataset")

            # train/test split
            split_dataset = processed_dataset.train_test_split(
                test_size=(1 - self.data_config.train_split), shuffle=False
            )

            return split_dataset
        except Exception as e:
            raise RuntimeError(f"Error loading dataset: {e}") from e

    def process_dataset(self, dataset: Dataset) -> Dataset:
        """Process and format the dataset."""
        EOS_TOKEN = self.tokenizer.eos_token
        template = self.data_config.template

        def formatting_func(examples):
            inputs = examples["input"]
            outputs = examples["output"]
            texts = []
            for inp, out in zip(inputs, outputs):
                text = template.format(inp, out) + EOS_TOKEN
                texts.append(text)
            return {"text": texts}

        return dataset.map(formatting_func, batched=True)
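

# --- Example usage (illustrative sketch, not part of the original module) ---
# Assumptions: DataConfig is a simple config object (e.g. a dataclass) exposing
# the template, max_samples, and train_split fields used above, and each JSON
# record in the dataset file provides "input" and "output" keys, e.g.
#   {"input": "What is 2 + 2?", "output": "4"}
# The tokenizer name, template string, and data path below are placeholders;
# adjust them to your setup.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    config = DataConfig(
        template="### Input:\n{}\n\n### Response:\n{}",  # assumed two-slot template
        max_samples=1000,
        train_split=0.9,
    )
    loader = DataLoader(tokenizer, config)
    splits = loader.load_dataset("data/train.json")  # or a drive.google.com share link
    print(splits)
    print(splits["train"][0]["text"])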