chore: ss
This commit is contained in:
@@ -31,10 +31,19 @@ class DataLoader:
|
||||
try:
|
||||
dataset = load_dataset("json", data_files=dataset_path, split="train")
|
||||
|
||||
if max_size := self.data_config.max_samples is not None:
|
||||
print(self.data_config)
|
||||
print(f"Dataset size before processing: {len(dataset)}")
|
||||
|
||||
if (max_size := self.data_config.max_samples) is not None:
|
||||
dataset = dataset.select(range(min(len(dataset), max_size)))
|
||||
|
||||
# dataset.save_to_disk("/workspace/truncated_dataset")
|
||||
|
||||
processed_dataset = self.process_dataset(dataset)
|
||||
print(f"Dataset size after processing: {len(processed_dataset)}")
|
||||
|
||||
# processed_dataset.save_to_disk("/workspace/processed_dataset")
|
||||
|
||||
# train/test split
|
||||
split_dataset = processed_dataset.train_test_split(
|
||||
test_size=(1 - self.data_config.train_split), shuffle=False
|
||||
@@ -47,15 +56,15 @@ class DataLoader:
|
||||
def process_dataset(self, dataset: Dataset) -> Dataset:
|
||||
"""Process and format the dataset"""
|
||||
|
||||
def formatting_func(examples: dict[str, Any]) -> dict[str, list[str]]:
|
||||
inputs: list[str] = examples["input"]
|
||||
outputs: list[str] = examples["output"]
|
||||
texts: list[str] = []
|
||||
for input, output in zip(inputs, outputs):
|
||||
text = (
|
||||
self.data_config.template.format(input=input, output=output)
|
||||
+ self.tokenizer.eos_token
|
||||
)
|
||||
EOS_TOKEN = self.tokenizer.eos_token
|
||||
template = self.data_config.template
|
||||
|
||||
def formatting_func(examples: dict[str, list[str]]) -> dict[str, list[str]]:
    """Render a batch of input/output pairs into training texts.

    Args:
        examples: Batch with parallel "input" and "output" columns.

    Returns:
        A dict with a single "text" column holding one formatted string
        per input/output pair, each terminated by ``EOS_TOKEN``.
    """
    # Pass each value both positionally and by name so the template may use
    # either "{}" / "{0}" or "{input}" / "{output}" placeholders. str.format
    # ignores arguments the template does not reference, so positional
    # templates behave exactly as before while named ones no longer raise
    # KeyError.
    return {
        "text": [
            template.format(inp, out, input=inp, output=out) + EOS_TOKEN
            for inp, out in zip(examples["input"], examples["output"])
        ]
    }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user