chore: ss
This commit is contained in:
32
main.py
32
main.py
@@ -3,7 +3,22 @@ from config import TrainingConfig, WandBConfig
|
||||
from data_loader import DataLoader
|
||||
from model_handler import ModelHandler
|
||||
from trainer import CustomTrainer
|
||||
from dataclasses import asdict
|
||||
from dataclasses import asdict, fields, is_dataclass
|
||||
from typing import Type
|
||||
|
||||
|
||||
def reconstruct_dataclass(cls: Type, data: dict):
|
||||
if not is_dataclass(cls):
|
||||
raise ValueError(f"{cls} is not a dataclass")
|
||||
|
||||
fieldtypes = {f.name: f.type for f in fields(cls)}
|
||||
kwargs = {}
|
||||
for name, value in data.items():
|
||||
if name in fieldtypes and is_dataclass(fieldtypes[name]):
|
||||
kwargs[name] = reconstruct_dataclass(fieldtypes[name], value)
|
||||
else:
|
||||
kwargs[name] = value
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -74,15 +89,10 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
|
||||
"min": 1e-5,
|
||||
"max": 1e-3,
|
||||
},
|
||||
"lora_r": {"values": [8, 16, 32, 64]},
|
||||
"lora_alpha": {"values": [16, 32, 64, 128]},
|
||||
"lora_r": {"values": [32, 64]},
|
||||
# "lora_alpha": {"values": [16, 32, 64, 128]},
|
||||
"per_device_train_batch_size": {"values": [16]},
|
||||
"gradient_accumulation_steps": {"values": [2, 4, 8]},
|
||||
"warmup_ratio": {
|
||||
"distribution": "uniform",
|
||||
"min": 0.006,
|
||||
"max": 0.015,
|
||||
},
|
||||
"num_train_epochs": {"values": [1]},
|
||||
},
|
||||
"early_terminate": {"type": "hyperband", "min_iter": 100},
|
||||
@@ -91,7 +101,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
|
||||
sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)
|
||||
|
||||
def compute_scaled_config(config_dict: dict) -> dict:
|
||||
base_batch_size = 16 # ref point
|
||||
base_batch_size = config_dict["per_device_train_batch_size"] # 16
|
||||
effective_batch_size = (
|
||||
config_dict["per_device_train_batch_size"]
|
||||
* config_dict["gradient_accumulation_steps"]
|
||||
@@ -107,6 +117,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
|
||||
|
||||
config_dict["effective_batch_size"] = effective_batch_size
|
||||
config_dict["lr_scale_factor"] = lr_scale_factor
|
||||
config_dict["lora_alpha"] = 2 * wandb.config.lora_r
|
||||
|
||||
del config_dict["base_learning_rate"]
|
||||
|
||||
@@ -128,6 +139,7 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
|
||||
"effective_batch_size": config_dict["effective_batch_size"],
|
||||
"lr_scale_factor": config_dict["lr_scale_factor"],
|
||||
"scaled_learning_rate": config_dict["learning_rate"],
|
||||
"lora_alpha": config_dict["lora_alpha"],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -135,6 +147,8 @@ def run_sweep(base_config: TrainingConfig, dataset_path: str):
|
||||
config_dict.pop("lr_scale_factor")
|
||||
|
||||
run_config = TrainingConfig(**config_dict)
|
||||
run_config = reconstruct_dataclass(TrainingConfig, config_dict)
|
||||
print(run_config)
|
||||
|
||||
train_single_run(run_config, dataset_path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user