import os

from config import DataConfig
from datasets import Dataset, DatasetDict, load_dataset
from transformers import PreTrainedTokenizer


class DataLoader:
    def __init__(self, tokenizer: PreTrainedTokenizer, data_config: DataConfig):
        self.tokenizer = tokenizer
        self.data_config = data_config

    def load_dataset(self, path: str) -> DatasetDict:
        """Load a JSON dataset from a local path or Google Drive URL and return a train/test split."""
        if "drive.google.com" in str(path):
            try:
                import gdown

                local_path = "downloaded_dataset.json"
                if not os.path.exists(local_path):
                    gdown.download(url=path, output=local_path, fuzzy=True)
                dataset_path = local_path
            except ImportError:
                raise ImportError("Please install gdown: pip install gdown")
            except Exception as e:
                raise RuntimeError(f"Error downloading from Google Drive: {e}") from e
        else:
            dataset_path = path

        try:
            dataset = load_dataset("json", data_files=dataset_path, split="train")
            print(self.data_config)
            print(f"Dataset size before processing: {len(dataset)}")

            # Optionally cap the dataset at `max_samples` examples.
            if (max_size := self.data_config.max_samples) is not None:
                dataset = dataset.select(range(min(len(dataset), max_size)))

            processed_dataset = self.process_dataset(dataset)
            print(f"Dataset size after processing: {len(processed_dataset)}")

            # Train/test split; shuffling is disabled to preserve the original example order.
            split_dataset = processed_dataset.train_test_split(
                test_size=(1 - self.data_config.train_split), shuffle=False
            )
            return split_dataset
        except Exception as e:
            raise RuntimeError(f"Error loading dataset: {e}") from e

    def process_dataset(self, dataset: Dataset) -> Dataset:
        """Format each (input, output) pair with the configured template and append the EOS token."""
        eos_token = self.tokenizer.eos_token
        template = self.data_config.template

        def formatting_func(examples):
            inputs = examples["input"]
            outputs = examples["output"]
            texts = [template.format(inp, out) + eos_token for inp, out in zip(inputs, outputs)]
            return {"text": texts}

        return dataset.map(formatting_func, batched=True)
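
# Minimal usage sketch (not part of the original module): it assumes DataConfig can be
# constructed with the `max_samples`, `train_split`, and `template` fields referenced above,
# and that the template has two positional placeholders for input and output.
# The tokenizer name, field values, and dataset path below are hypothetical.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model choice
    data_config = DataConfig(
        max_samples=1000,
        train_split=0.9,
        template="### Input:\n{}\n\n### Output:\n{}",
    )  # hypothetical field values; the real DataConfig may differ

    loader = DataLoader(tokenizer=tokenizer, data_config=data_config)
    splits = loader.load_dataset("data/train.json")  # hypothetical path
    print(splits["train"][0]["text"])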