chore: trainer
This commit is contained in:
50
data_loader.py
Normal file
50
data_loader.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
class DataLoader:
|
||||
def __init__(self, tokenizer: PreTrainedTokenizer, template: str):
|
||||
self.tokenizer = tokenizer
|
||||
self._template = template
|
||||
|
||||
def load_dataset(self, path: str) -> Dataset:
|
||||
"""Load dataset from local path or Google Drive"""
|
||||
if "drive.google.com" in str(path):
|
||||
try:
|
||||
import gdown
|
||||
|
||||
local_path = "downloaded_dataset.json"
|
||||
if not os.path.exists(local_path):
|
||||
gdown.download(url=path, output=local_path, fuzzy=True)
|
||||
dataset_path = local_path
|
||||
except ImportError:
|
||||
raise ImportError("Please install gdown: pip install gdown")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error downloading from Google Drive: {e}")
|
||||
else:
|
||||
dataset_path = path
|
||||
|
||||
try:
|
||||
dataset = load_dataset("json", data_files=dataset_path, split="train")
|
||||
return self.process_dataset(dataset)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error loading dataset: {e}")
|
||||
|
||||
def process_dataset(self, dataset: Dataset) -> Dataset:
|
||||
"""Process and format the dataset"""
|
||||
|
||||
def formatting_func(examples: dict[str, Any]) -> dict[str, list[str]]:
|
||||
inputs: list[str] = examples["input"]
|
||||
outputs: list[str] = examples["output"]
|
||||
texts: list[str] = []
|
||||
for input, output in zip(inputs, outputs):
|
||||
text = (
|
||||
self._template.format(input=input, output=output)
|
||||
+ self.tokenizer.eos_token
|
||||
)
|
||||
texts.append(text)
|
||||
return {"text": texts}
|
||||
|
||||
return dataset.map(formatting_func, batched=True)
|
||||
Reference in New Issue
Block a user