chore: up

This commit is contained in:
2025-02-14 21:50:55 +06:00
parent 45f45f4bdb
commit efd62d7c94
4 changed files with 185 additions and 65 deletions

132
main.py
View File

@@ -3,6 +3,7 @@ from config import TrainingConfig, WandBConfig
from data_loader import DataLoader
from model_handler import ModelHandler
from trainer import CustomTrainer
from dataclasses import asdict
def parse_args():
@@ -17,6 +18,9 @@ def parse_args():
"--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
)
wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
wandb_group.add_argument(
"--wandb_sweep", action="store_true", help="Enable wandb sweep"
)
# rest
parser.add_argument(
"--base_model",
@@ -34,42 +38,130 @@ def parse_args():
return parser.parse_args()
def main():
args = parse_args()
def train_single_run(config: TrainingConfig, dataset_path: str):
"""Run a single training session with given config"""
try:
wandb_config = WandBConfig(
enabled=args.wandb_project is not None,
project=args.wandb_project,
name=args.wandb_name,
entity=args.wandb_entity,
tags=args.wandb_tags,
notes=args.wandb_notes,
)
config = TrainingConfig(
base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
)
model_handler = ModelHandler(config)
model, tokenizer = model_handler.setup_model()
data_loader = DataLoader(tokenizer, config.data.template)
dataset = data_loader.load_dataset(args.dataset)
data_loader = DataLoader(tokenizer, config.data)
dataset = data_loader.load_dataset(dataset_path)
trainer_handler = CustomTrainer(config)
trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
trainer_stats = trainer_handler.train_and_log(trainer)
model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
model.save_pretrained(config.output_dir)
tokenizer.save_pretrained(config.output_dir)
print("Training completed successfully!")
return trainer_stats
except Exception as e:
print(f"Error during training: {e}")
raise
def run_sweep(base_config: TrainingConfig, dataset_path: str):
import wandb
sweep_config = {
"method": "bayes",
"metric": {"name": "val_loss", "goal": "minimize"},
"parameters": {
"base_learning_rate": {
"distribution": "log_uniform_values",
"min": 1e-5,
"max": 1e-3,
},
"lora_r": {"values": [8, 16, 32, 64]},
"lora_alpha": {"values": [16, 32, 64, 128]},
"per_device_train_batch_size": {"values": [16]},
"gradient_accumulation_steps": {"values": [2, 4, 8]},
"warmup_ratio": {
"distribution": "uniform",
"min": 0.006,
"max": 0.015,
},
"num_train_epochs": {"values": [1]},
},
"early_terminate": {"type": "hyperband", "min_iter": 100},
}
sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)
def compute_scaled_config(config_dict: dict) -> dict:
base_batch_size = 16 # ref point
effective_batch_size = (
config_dict["per_device_train_batch_size"]
* config_dict["gradient_accumulation_steps"]
)
# apply square root scaling for Adam
lr_scale_factor = (effective_batch_size / base_batch_size) ** 0.5
# scale the learning rate
config_dict["learning_rate"] = (
config_dict["base_learning_rate"] * lr_scale_factor
)
config_dict["effective_batch_size"] = effective_batch_size
config_dict["lr_scale_factor"] = lr_scale_factor
del config_dict["base_learning_rate"]
return config_dict
def sweep_train():
with wandb.init() as run:
config_dict = (
asdict(base_config)
if hasattr(base_config, "__dataclass_fields__")
else base_config.__dict__.copy()
)
config_dict.update(wandb.config)
config_dict = compute_scaled_config(config_dict)
wandb.log(
{
"effective_batch_size": config_dict["effective_batch_size"],
"lr_scale_factor": config_dict["lr_scale_factor"],
"scaled_learning_rate": config_dict["learning_rate"],
}
)
config_dict.pop("effective_batch_size")
config_dict.pop("lr_scale_factor")
run_config = TrainingConfig(**config_dict)
train_single_run(run_config, dataset_path)
wandb.agent(sweep_id, function=sweep_train)
def main():
args = parse_args()
wandb_config = WandBConfig(
enabled=bool(args.wandb_project),
project=args.wandb_project,
name=args.wandb_name,
entity=args.wandb_entity,
tags=args.wandb_tags,
notes=args.wandb_notes,
)
base_config = TrainingConfig(
base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
)
if args.wandb_sweep:
run_sweep(base_config, args.dataset)
else:
train_single_run(base_config, args.dataset)
if __name__ == "__main__":
main()