Files
unsloth-train-scripts/data_loader.py
2025-02-13 21:42:03 +06:00

51 lines
1.8 KiB
Python

import os
from typing import Any
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizer
class DataLoader:
    """Load a JSON dataset and render every record into a training text.

    Each record's ``input`` and ``output`` fields are substituted into the
    template with ``str.format`` and the tokenizer's EOS token is appended.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, template: str):
        """Store the tokenizer and the prompt template.

        Args:
            tokenizer: Tokenizer whose ``eos_token`` is appended to each text.
            template: ``str.format`` template that is filled with the
                keyword arguments ``input`` and ``output``.
        """
        self.tokenizer = tokenizer
        self._template = template

    def load_dataset(self, path: str) -> Dataset:
        """Load dataset from local path or Google Drive.

        Args:
            path: Path (or anything accepted as ``data_files`` by the JSON
                loader), or a URL containing ``drive.google.com``, in which
                case the file is downloaded first.

        Returns:
            The ``train`` split of the JSON dataset after ``process_dataset``.

        Raises:
            ImportError: If a Google Drive URL is given and ``gdown`` is not
                installed.
            Exception: If the download, the load, or the processing fails;
                the underlying error is chained as ``__cause__``.
        """
        if "drive.google.com" in str(path):
            dataset_path = self._download_from_drive(str(path))
        else:
            dataset_path = path

        try:
            dataset = load_dataset("json", data_files=dataset_path, split="train")
            return self.process_dataset(dataset)
        except Exception as e:
            raise Exception(f"Error loading dataset: {e}") from e

    @staticmethod
    def _download_from_drive(url: str) -> str:
        """Download ``url`` with gdown (once per URL) and return the local path."""
        import hashlib

        try:
            import gdown
        except ImportError as e:
            raise ImportError("Please install gdown: pip install gdown") from e

        # Key the cache file on the URL: a single fixed file name would make a
        # second, different URL silently reuse the first download.
        digest = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
        local_path = f"downloaded_dataset_{digest}.json"
        if os.path.exists(local_path):
            return local_path

        try:
            gdown.download(url=url, output=local_path, fuzzy=True)
        except Exception as e:
            raise Exception(f"Error downloading from Google Drive: {e}") from e

        # Fail here rather than with a confusing loader error later if the
        # download returned without writing the file.
        if not os.path.exists(local_path):
            raise Exception(
                f"Error downloading from Google Drive: no file was written for {url}"
            )
        return local_path

    def process_dataset(self, dataset: Dataset) -> Dataset:
        """Process and format the dataset.

        Adds a ``text`` column built from the ``input`` and ``output``
        columns using the template, with the EOS token appended.

        Args:
            dataset: Dataset exposing ``map`` and providing ``input`` and
                ``output`` columns.

        Returns:
            The result of ``dataset.map`` with the batched formatting function.

        Raises:
            ValueError: If the tokenizer has no EOS token.
        """
        eos_token = self.tokenizer.eos_token
        if eos_token is None:
            # Otherwise this surfaces as an opaque "str + None" TypeError
            # from inside the map call.
            raise ValueError("Tokenizer has no eos_token; cannot terminate texts")
        template = self._template

        def formatting_func(examples: dict[str, Any]) -> dict[str, list[str]]:
            """Render one batch of records into a list of training texts."""
            prompts: list[str] = examples["input"]
            responses: list[str] = examples["output"]
            texts = [
                template.format(input=prompt, output=response) + eos_token
                for prompt, response in zip(prompts, responses)
            ]
            return {"text": texts}

        return dataset.map(formatting_func, batched=True)