chore: narrow LoRA sweep space and rebuild nested config dataclasses

2025-02-14 23:53:46 +06:00
parent efd62d7c94
commit 39e69c90b1
4 changed files with 67 additions and 31 deletions

main.py

@@ -3,7 +3,22 @@ from config import TrainingConfig, WandBConfig
 from data_loader import DataLoader
 from model_handler import ModelHandler
 from trainer import CustomTrainer
-from dataclasses import asdict
+from dataclasses import asdict, fields, is_dataclass
+from typing import Type
+
+
+def reconstruct_dataclass(cls: Type, data: dict):
+    if not is_dataclass(cls):
+        raise ValueError(f"{cls} is not a dataclass")
+    fieldtypes = {f.name: f.type for f in fields(cls)}
+    kwargs = {}
+    for name, value in data.items():
+        if name in fieldtypes and is_dataclass(fieldtypes[name]):
+            kwargs[name] = reconstruct_dataclass(fieldtypes[name], value)
+        else:
+            kwargs[name] = value
+    return cls(**kwargs)
+
+
 def parse_args():
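
Editor's note on the helper above (not part of the diff): a plain TrainingConfig(**config_dict) leaves any nested dataclass field, such as a wandb field typed as WandBConfig per the import in the hunk header, as a plain dict whenever the input came from asdict() or a sweep. The recursive helper rebuilds those fields. One caveat: it relies on f.type being an actual class object, which holds for ordinary annotations but breaks under `from __future__ import annotations`, where field types become strings. A minimal standalone sketch of the behavior (toy stand-ins for the real config classes; assumes reconstruct_dataclass and its imports from the hunk above are in scope):

from dataclasses import dataclass, field, asdict

@dataclass
class WandBConfig:  # toy stand-in for the real WandBConfig
    project: str = "demo"

@dataclass
class TrainingConfig:  # toy stand-in for the real TrainingConfig
    learning_rate: float = 1e-4
    wandb: WandBConfig = field(default_factory=WandBConfig)

flat = asdict(TrainingConfig())  # {'learning_rate': 0.0001, 'wandb': {'project': 'demo'}}

naive = TrainingConfig(**flat)
print(type(naive.wandb))  # <class 'dict'>: the nested section stays a dict

rebuilt = reconstruct_dataclass(TrainingConfig, flat)
print(type(rebuilt.wandb))  # WandBConfig: recursion restores the nested dataclass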
@@ -74,15 +89,10 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
                 "min": 1e-5,
                 "max": 1e-3,
             },
-            "lora_r": {"values": [8, 16, 32, 64]},
-            "lora_alpha": {"values": [16, 32, 64, 128]},
+            "lora_r": {"values": [32, 64]},
+            # "lora_alpha": {"values": [16, 32, 64, 128]},
             "per_device_train_batch_size": {"values": [16]},
             "gradient_accumulation_steps": {"values": [2, 4, 8]},
-            "warmup_ratio": {
-                "distribution": "uniform",
-                "min": 0.006,
-                "max": 0.015,
-            },
             "num_train_epochs": {"values": [1]},
         },
         "early_terminate": {"type": "hyperband", "min_iter": 100},
@@ -91,7 +101,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
     sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)
 
     def compute_scaled_config(config_dict: dict) -> dict:
-        base_batch_size = 16  # ref point
+        base_batch_size = config_dict["per_device_train_batch_size"]  # 16
         effective_batch_size = (
             config_dict["per_device_train_batch_size"]
             * config_dict["gradient_accumulation_steps"]
@@ -107,6 +117,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
         config_dict["effective_batch_size"] = effective_batch_size
         config_dict["lr_scale_factor"] = lr_scale_factor
+        config_dict["lora_alpha"] = 2 * wandb.config.lora_r
 
         del config_dict["base_learning_rate"]
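
Editor's note (not part of the commit): the two hunks above change the scaling reference point and pin lora_alpha to 2 * lora_r, a common LoRA heuristic, instead of sweeping alpha independently. Assuming the unshown scaling line is the linear rule learning_rate = base_learning_rate * effective_batch_size / base_batch_size, the arithmetic now works out as:

per_device_train_batch_size = 16
gradient_accumulation_steps = 4   # one of the swept values [2, 4, 8]
base_learning_rate = 2e-4         # sampled from [1e-5, 1e-3]

base_batch_size = per_device_train_batch_size  # 16, the new reference point
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 64
lr_scale_factor = effective_batch_size / base_batch_size  # 4.0 (assumed linear rule)
learning_rate = base_learning_rate * lr_scale_factor  # 8e-4

lora_alpha = 2 * 32  # 64, for lora_r = 32

One consequence of the new reference point: with base_batch_size tied to the per-device size, lr_scale_factor collapses to gradient_accumulation_steps, so only the accumulation setting now moves the learning rate.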
@@ -128,6 +139,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
                 "effective_batch_size": config_dict["effective_batch_size"],
                 "lr_scale_factor": config_dict["lr_scale_factor"],
                 "scaled_learning_rate": config_dict["learning_rate"],
+                "lora_alpha": config_dict["lora_alpha"],
             }
         )
@@ -135,6 +147,8 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
         config_dict.pop("lr_scale_factor")
 
-        run_config = TrainingConfig(**config_dict)
+        run_config = reconstruct_dataclass(TrainingConfig, config_dict)
+        print(run_config)
         train_single_run(run_config, dataset_path)
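
Editor's summary (not part of the commit): each sweep trial now flows wandb.config -> compute_scaled_config -> reconstruct_dataclass -> train_single_run, so nested sections of TrainingConfig survive the dict round-trip that wandb.config imposes, and the transient base_learning_rate and lr_scale_factor keys are stripped before the dataclass is rebuilt.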