diff --git a/README.md b/README.md
index 4e5a28a..aaa6d14 100644
--- a/README.md
+++ b/README.md
@@ -44,6 +44,19 @@ class TrainingConfig:
 
 ## Usage
 
+```
+gdown --fuzzy 'https://drive.google.com/file/d/1mqhq69dsSOK7ep7trTjRd3FMagTFTrzF/view?usp=sharing'
+```
+
+```bash
+python main.py \
+    --dataset /workspace/dataset_v3.0_alpaca_noinstr_filtered_6144.json \
+    --output_dir /workspace/output \
+    --wandb_project qwen2.5-lora \
+    --wandb_entity luwoyuki-zhtl \
+    --wandb_sweep
+```
+
 ```bash
 python main.py \
     --base_model mistralai/Mistral-7B-v0.1 \
diff --git a/config.py b/config.py
index 6295d2d..5a54585 100644
--- a/config.py
+++ b/config.py
@@ -4,17 +4,17 @@ from dataclasses import dataclass, field
 @dataclass
 class DataConfig:
     template: str = """Translate this Chinese text to English:
-{input}
+{}
 ---
 Translation:
-{output}"""
+{}"""
     train_split: float = 0.9
-    max_samples: int | None = None
+    max_samples: int | None = 3000
 
 
 @dataclass
 class WandBConfig:
-    enabled: bool = True
+    enabled: bool = False
     project: str | None = None
     name: str | None = None
     entity: str | None = None
@@ -34,8 +34,8 @@ class TrainingConfig:
     load_in_4bit: bool = False
 
     # LoRA
-    lora_r: int = 256
-    lora_alpha: int = 512
+    lora_r: int = 64
+    lora_alpha: int = 128
     lora_dropout: float = 0
     target_modules: list[str] = field(
         default_factory=lambda: [
@@ -56,7 +56,7 @@ class TrainingConfig:
     # training args
     per_device_train_batch_size: int = 16
     gradient_accumulation_steps: int = 2
-    warmup_ratio: float = 0.03
+    warmup_ratio: float = 0.05
     max_grad_norm: float = 1.0
     num_train_epochs: float = 1
     learning_rate: float = 5e-4
@@ -74,11 +74,11 @@ class TrainingConfig:
     packing: bool = True
 
     # eval
-    fp16_full_eval = True
-    per_device_eval_batch_size = 2
-    eval_accumulation_steps = 4
-    eval_strategy = "steps"
-    eval_steps = 1
+    fp16_full_eval: bool = True
+    per_device_eval_batch_size: int = 16
+    eval_accumulation_steps: int = 4
+    eval_strategy: str = "steps"
+    eval_steps: int = 5
 
     # output
     output_dir: str = "/workspace/output/"
diff --git a/data_loader.py b/data_loader.py
index ab788fb..67728eb 100644
--- a/data_loader.py
+++ b/data_loader.py
@@ -31,10 +31,19 @@ class DataLoader:
         try:
             dataset = load_dataset("json", data_files=dataset_path, split="train")
 
-            if max_size := self.data_config.max_samples is not None:
+            print(self.data_config)
+            print(f"Dataset size before processing: {len(dataset)}")
+
+            if (max_size := self.data_config.max_samples) is not None:
                 dataset = dataset.select(range(min(len(dataset), max_size)))
 
+            # dataset.save_to_disk("/workspace/truncated_dataset")
+
             processed_dataset = self.process_dataset(dataset)
+            print(f"Dataset size after processing: {len(processed_dataset)}")
+
+            # processed_dataset.save_to_disk("/workspace/processed_dataset")
+
             # train/test split
             split_dataset = processed_dataset.train_test_split(
                 test_size=(1 - self.data_config.train_split), shuffle=False
@@ -47,15 +56,15 @@ class DataLoader:
     def process_dataset(self, dataset: Dataset) -> Dataset:
         """Process and format the dataset"""
 
-        def formatting_func(examples: dict[str, Any]) -> dict[str, list[str]]:
-            inputs: list[str] = examples["input"]
-            outputs: list[str] = examples["output"]
-            texts: list[str] = []
-            for input, output in zip(inputs, outputs):
-                text = (
-                    self.data_config.template.format(input=input, output=output)
-                    + self.tokenizer.eos_token
-                )
+        EOS_TOKEN = self.tokenizer.eos_token
+        template = self.data_config.template
+
+        def formatting_func(examples):
+            inputs = examples["input"]
+            outputs = examples["output"]
+            texts = []
+            for inp, out in zip(inputs, outputs):
+                text = template.format(inp, out) + EOS_TOKEN
                 texts.append(text)
             return {"text": texts}
 
diff --git a/main.py b/main.py
index 07d76aa..0fcad02 100644
--- a/main.py
+++ b/main.py
@@ -3,7 +3,22 @@ from config import TrainingConfig, WandBConfig
 from data_loader import DataLoader
 from model_handler import ModelHandler
 from trainer import CustomTrainer
-from dataclasses import asdict
+from dataclasses import asdict, fields, is_dataclass
+from typing import Type
+
+
+def reconstruct_dataclass(cls: Type, data: dict):
+    if not is_dataclass(cls):
+        raise ValueError(f"{cls} is not a dataclass")
+
+    fieldtypes = {f.name: f.type for f in fields(cls)}
+    kwargs = {}
+    for name, value in data.items():
+        if name in fieldtypes and is_dataclass(fieldtypes[name]):
+            kwargs[name] = reconstruct_dataclass(fieldtypes[name], value)
+        else:
+            kwargs[name] = value
+    return cls(**kwargs)
 
 
 def parse_args():
@@ -74,15 +89,10 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
                 "min": 1e-5,
                 "max": 1e-3,
             },
-            "lora_r": {"values": [8, 16, 32, 64]},
-            "lora_alpha": {"values": [16, 32, 64, 128]},
+            "lora_r": {"values": [32, 64]},
+            # "lora_alpha": {"values": [16, 32, 64, 128]},
             "per_device_train_batch_size": {"values": [16]},
             "gradient_accumulation_steps": {"values": [2, 4, 8]},
-            "warmup_ratio": {
-                "distribution": "uniform",
-                "min": 0.006,
-                "max": 0.015,
-            },
             "num_train_epochs": {"values": [1]},
         },
         "early_terminate": {"type": "hyperband", "min_iter": 100},
@@ -91,7 +101,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
     sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)
 
     def compute_scaled_config(config_dict: dict) -> dict:
-        base_batch_size = 16  # ref point
+        base_batch_size = config_dict["per_device_train_batch_size"]  # 16
         effective_batch_size = (
             config_dict["per_device_train_batch_size"]
             * config_dict["gradient_accumulation_steps"]
@@ -107,6 +117,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
 
         config_dict["effective_batch_size"] = effective_batch_size
         config_dict["lr_scale_factor"] = lr_scale_factor
+        config_dict["lora_alpha"] = 2 * wandb.config.lora_r
 
         del config_dict["base_learning_rate"]
 
@@ -128,6 +139,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
                 "effective_batch_size": config_dict["effective_batch_size"],
                 "lr_scale_factor": config_dict["lr_scale_factor"],
                 "scaled_learning_rate": config_dict["learning_rate"],
+                "lora_alpha": config_dict["lora_alpha"],
             }
         )
 
@@ -135,6 +147,8 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
         config_dict.pop("lr_scale_factor")
 
         run_config = TrainingConfig(**config_dict)
+        run_config = reconstruct_dataclass(TrainingConfig, config_dict)
+        print(run_config)
 
         train_single_run(run_config, dataset_path)
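
Context on the last hunk: `run_config = TrainingConfig(**config_dict)` is superseded by `reconstruct_dataclass(TrainingConfig, config_dict)` because the sweep's `config_dict` is (presumably) built from a plain dict/`asdict` round-trip, so nested dataclass fields such as `wandb` arrive as plain dicts and the flat constructor call would leave them untyped. The sketch below illustrates the same recursive rebuild in isolation; the toy `TrainingConfig`/`WandBConfig` field sets are trimmed for illustration and are not the real ones from `config.py`.

```python
from dataclasses import asdict, dataclass, field, fields, is_dataclass


@dataclass
class WandBConfig:
    enabled: bool = False
    project: str | None = None


@dataclass
class TrainingConfig:
    lora_r: int = 64
    lora_alpha: int = 128
    wandb: WandBConfig = field(default_factory=WandBConfig)


def reconstruct_dataclass(cls, data: dict):
    """Rebuild `cls`, turning nested dicts back into their dataclass fields."""
    if not is_dataclass(cls):
        raise ValueError(f"{cls} is not a dataclass")
    fieldtypes = {f.name: f.type for f in fields(cls)}
    kwargs = {}
    for name, value in data.items():
        if name in fieldtypes and is_dataclass(fieldtypes[name]):
            # e.g. the `wandb` field: recurse so the dict becomes a WandBConfig
            kwargs[name] = reconstruct_dataclass(fieldtypes[name], value)
        else:
            kwargs[name] = value
    return cls(**kwargs)


# Simulate a sweep run: flatten the config, apply overrides, then rebuild.
config_dict = asdict(TrainingConfig())            # nested WandBConfig becomes a dict
config_dict["lora_r"] = 32                        # sweep override
config_dict["lora_alpha"] = 2 * config_dict["lora_r"]

# TrainingConfig(**config_dict) would keep `wandb` as a plain dict;
# the recursive rebuild restores it as a WandBConfig instance.
run_config = reconstruct_dataclass(TrainingConfig, config_dict)
assert isinstance(run_config.wandb, WandBConfig)
print(run_config)
```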