import argparse
from dataclasses import asdict, fields, is_dataclass
from typing import Type

from config import TrainingConfig, WandBConfig
from data_loader import DataLoader
from model_handler import ModelHandler
from trainer import CustomTrainer


def reconstruct_dataclass(cls: Type, data: dict):
    """Recursively rebuild a (possibly nested) dataclass from a plain dict."""
    if not is_dataclass(cls):
        raise ValueError(f"{cls} is not a dataclass")
    field_types = {f.name: f.type for f in fields(cls)}
    kwargs = {}
    for name, value in data.items():
        if name in field_types and is_dataclass(field_types[name]):
            # Nested dataclass field: rebuild it from its sub-dict.
            kwargs[name] = reconstruct_dataclass(field_types[name], value)
        else:
            kwargs[name] = value
    return cls(**kwargs)


def parse_args():
    parser = argparse.ArgumentParser(description="Train a language model with LoRA")

    # Weights & Biases arguments
    wandb_group = parser.add_argument_group("Weights & Biases")
    wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
    wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
    wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
    wandb_group.add_argument(
        "--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
    )
    wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
    wandb_group.add_argument(
        "--wandb_sweep", action="store_true", help="Enable wandb sweep"
    )

    # Training arguments
    parser.add_argument(
        "--base_model",
        type=str,
        default="unsloth/Qwen2.5-7B",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Path to the training dataset"
    )
    parser.add_argument(
        "--output_dir", type=str, default="/output/", help="Directory to save the model"
    )
    parser.add_argument("--hub_token", type=str, help="Hugging Face Hub token")

    return parser.parse_args()


def train_single_run(config: TrainingConfig, dataset_path: str):
    """Run a single training session with the given config."""
    try:
        model_handler = ModelHandler(config)
        model, tokenizer = model_handler.setup_model()

        data_loader = DataLoader(tokenizer, config.data)
        dataset = data_loader.load_dataset(dataset_path)

        trainer_handler = CustomTrainer(config)
        trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
        trainer_stats = trainer_handler.train_and_log(trainer)

        model.save_pretrained(config.output_dir)
        tokenizer.save_pretrained(config.output_dir)

        print("Training completed successfully!")
        return trainer_stats
    except Exception as e:
        print(f"Error during training: {e}")
        raise


def run_sweep(base_config: TrainingConfig, dataset_path: str):
    """Run a Bayesian hyperparameter sweep with wandb."""
    import wandb  # imported lazily so single runs do not require wandb

    sweep_config = {
        "method": "bayes",
        "metric": {"name": "val_loss", "goal": "minimize"},
        "parameters": {
            "base_learning_rate": {
                "distribution": "log_uniform_values",
                "min": 1e-5,
                "max": 1e-3,
            },
            "lora_r": {"values": [32, 64]},
            # "lora_alpha": {"values": [16, 32, 64, 128]},
            "per_device_train_batch_size": {"values": [16]},
            "gradient_accumulation_steps": {"values": [2, 4, 8]},
            "num_train_epochs": {"values": [1]},
        },
        "early_terminate": {"type": "hyperband", "min_iter": 100},
    }
    sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)

    def compute_scaled_config(config_dict: dict) -> dict:
        """Derive the actual learning rate and LoRA alpha from the sampled values."""
        base_batch_size = config_dict["per_device_train_batch_size"]  # 16
        effective_batch_size = (
            config_dict["per_device_train_batch_size"]
            * config_dict["gradient_accumulation_steps"]
        )
        # Square-root scaling rule for Adam: the learning rate grows with the
        # square root of the effective batch size relative to the base batch size.
        lr_scale_factor = (effective_batch_size / base_batch_size) ** 0.5
        config_dict["learning_rate"] = (
            config_dict["base_learning_rate"] * lr_scale_factor
        )
        config_dict["effective_batch_size"] = effective_batch_size
        config_dict["lr_scale_factor"] = lr_scale_factor
        # Tie lora_alpha to the sampled rank instead of sweeping it separately.
        config_dict["lora_alpha"] = 2 * config_dict["lora_r"]
        del config_dict["base_learning_rate"]
        return config_dict

    def sweep_train():
        with wandb.init():
            config_dict = (
                asdict(base_config)
                if hasattr(base_config, "__dataclass_fields__")
                else base_config.__dict__.copy()
            )
            config_dict.update(wandb.config)
            config_dict = compute_scaled_config(config_dict)

            # Log the derived values so they appear alongside the sweep parameters.
            wandb.log(
                {
                    "effective_batch_size": config_dict["effective_batch_size"],
                    "lr_scale_factor": config_dict["lr_scale_factor"],
                    "scaled_learning_rate": config_dict["learning_rate"],
                    "lora_alpha": config_dict["lora_alpha"],
                }
            )

            # Drop the derived bookkeeping keys; they are not TrainingConfig fields.
            config_dict.pop("effective_batch_size")
            config_dict.pop("lr_scale_factor")

            run_config = reconstruct_dataclass(TrainingConfig, config_dict)
            print(run_config)
            train_single_run(run_config, dataset_path)

    wandb.agent(sweep_id, function=sweep_train)


def main():
    args = parse_args()

    wandb_config = WandBConfig(
        enabled=bool(args.wandb_project),
        project=args.wandb_project,
        name=args.wandb_name,
        entity=args.wandb_entity,
        tags=args.wandb_tags,
        notes=args.wandb_notes,
    )
    base_config = TrainingConfig(
        base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
    )

    if args.wandb_sweep:
        run_sweep(base_config, args.dataset)
    else:
        train_single_run(base_config, args.dataset)


if __name__ == "__main__":
    main()
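

# Example invocations (illustrative sketch: the script filename and dataset path
# are assumptions; only the flags parsed above are defined by this file):
#
#   # single training run with the default base model
#   python train.py --dataset data/train.jsonl --output_dir ./output/
#
#   # Bayesian hyperparameter sweep logged to Weights & Biases
#   python train.py --dataset data/train.jsonl --wandb_project my-lora-project --wandb_sweep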