diff --git a/config.py b/config.py
index 5a54585..8f56a0e 100644
--- a/config.py
+++ b/config.py
@@ -8,8 +8,8 @@ class DataConfig:
 ---
 Translation:
 {}"""
-    train_split: float = 0.9
-    max_samples: int | None = 3000
+    train_split: float = 0.95
+    max_samples: int | None = 5000


 @dataclass
@@ -31,7 +31,7 @@ class TrainingConfig:
     base_model: str = "unsloth/Qwen2.5-7B"
     max_seq_length: int = 6144
     dtype: str | None = None
-    load_in_4bit: bool = False
+    load_in_4bit: bool = True

     # LoRA
     lora_r: int = 64
@@ -48,7 +48,7 @@ class TrainingConfig:
             "down_proj",
         ]
     )
-    use_gradient_checkpointing: str = "unsloth"
+    use_gradient_checkpointing: str = True
     random_state: int = 3407
     use_rslora: bool = False
     loftq_config: dict | None = None
@@ -56,11 +56,11 @@ class TrainingConfig:
     # training args
     per_device_train_batch_size: int = 16
     gradient_accumulation_steps: int = 2
-    warmup_ratio: float = 0.05
+    warmup_ratio: float = 0.1
     max_grad_norm: float = 1.0
     num_train_epochs: float = 1
     learning_rate: float = 5e-4
-    weight_decay: float = 0
+    weight_decay: float = 0.01
     lr_scheduler_type: str = "cosine"
     logging_steps: int = 1

@@ -70,15 +70,15 @@ class TrainingConfig:
     save_total_limit: int | None = 3

     # dataset
-    dataset_num_proc: int = 8
+    dataset_num_proc: int = 4
     packing: bool = True

     # eval
     fp16_full_eval: bool = True
-    per_device_eval_batch_size: int = 16
-    eval_accumulation_steps: int = 4
+    per_device_eval_batch_size: int = 64
+    eval_accumulation_steps: int = 1
     eval_strategy: str = "steps"
-    eval_steps: int = 5
+    eval_steps: int = 10

     # output
     output_dir: str = "/workspace/output/"
diff --git a/main.py b/main.py
index 0fcad02..3c05ea8 100644
--- a/main.py
+++ b/main.py
@@ -84,72 +84,46 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
         "method": "bayes",
         "metric": {"name": "val_loss", "goal": "minimize"},
         "parameters": {
-            "base_learning_rate": {
+            "learning_rate": {
                 "distribution": "log_uniform_values",
                 "min": 1e-5,
-                "max": 1e-3,
+                "max": 1e-4,
             },
-            "lora_r": {"values": [32, 64]},
-            # "lora_alpha": {"values": [16, 32, 64, 128]},
-            "per_device_train_batch_size": {"values": [16]},
-            "gradient_accumulation_steps": {"values": [2, 4, 8]},
+            "lora_r": {"values": [32]},
+            "lora_alpha": {"values": [64]},
+            "per_device_train_batch_size": {"values": [64]},
+            "gradient_accumulation_steps": {"values": [1]},
             "num_train_epochs": {"values": [1]},
         },
-        "early_terminate": {"type": "hyperband", "min_iter": 100},
     }

     sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)

-    def compute_scaled_config(config_dict: dict) -> dict:
-        base_batch_size = config_dict["per_device_train_batch_size"]  # 16
-        effective_batch_size = (
-            config_dict["per_device_train_batch_size"]
-            * config_dict["gradient_accumulation_steps"]
-        )
-
-        # apply square root scaling for Adam
-        lr_scale_factor = (effective_batch_size / base_batch_size) ** 0.5
-
-        # scale the learning rate
-        config_dict["learning_rate"] = (
-            config_dict["base_learning_rate"] * lr_scale_factor
-        )
-
-        config_dict["effective_batch_size"] = effective_batch_size
-        config_dict["lr_scale_factor"] = lr_scale_factor
-        config_dict["lora_alpha"] = 2 * wandb.config.lora_r
-
-        del config_dict["base_learning_rate"]
-
-        return config_dict
-
     def sweep_train():
         with wandb.init() as run:
+            # Convert base config to dict
             config_dict = (
                 asdict(base_config)
                 if hasattr(base_config, "__dataclass_fields__")
                 else base_config.__dict__.copy()
             )

+            # Update with sweep parameters
             config_dict.update(wandb.config)

-            config_dict = compute_scaled_config(config_dict)
-            wandb.log(
-                {
-                    "effective_batch_size": config_dict["effective_batch_size"],
-                    "lr_scale_factor": config_dict["lr_scale_factor"],
-                    "scaled_learning_rate": config_dict["learning_rate"],
-                    "lora_alpha": config_dict["lora_alpha"],
-                }
+            # Set lora_alpha based on lora_r
+            # config_dict["lora_alpha"] = 2 * config_dict["lora_r"]
+
+            # Log effective batch size for monitoring
+            effective_batch_size = (
+                config_dict["per_device_train_batch_size"]
+                * config_dict["gradient_accumulation_steps"]
             )
+            wandb.log({"effective_batch_size": effective_batch_size})

-            config_dict.pop("effective_batch_size")
-            config_dict.pop("lr_scale_factor")
-
-            run_config = TrainingConfig(**config_dict)
+            # Create training config and run
             run_config = reconstruct_dataclass(TrainingConfig, config_dict)
             print(run_config)
-
             train_single_run(run_config, dataset_path)

     wandb.agent(sweep_id, function=sweep_train)