"""Command-line entry point for LoRA fine-tuning of a language model.

Parses CLI arguments, builds the training and Weights & Biases configuration,
sets up the model/tokenizer, loads the dataset, runs training, and saves the
resulting model and tokenizer to the output directory.
"""

import argparse
import os
import sys

from config import TrainingConfig, WandBConfig
from data_loader import DataLoader
from model_handler import ModelHandler
from trainer import CustomTrainer


def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Train a language model with LoRA")

    # wandb args
    wandb_group = parser.add_argument_group("Weights & Biases")
    wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
    wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
    wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
    wandb_group.add_argument(
        "--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
    )
    wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")

    # rest
    parser.add_argument(
        "--base_model",
        type=str,
        default="unsloth/Qwen2.5-7B",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Path to the training dataset"
    )
    parser.add_argument(
        "--output_dir", type=str, default="/output/", help="Directory to save the model"
    )
    parser.add_argument("--hub_token", type=str, help="Hugging Face Hub token")

    return parser.parse_args(argv)


def main():
    """Run the full fine-tuning pipeline from CLI arguments.

    W&B logging is enabled only when ``--wandb_project`` is given. The trained
    model and tokenizer are saved to ``--output_dir``.

    Raises:
        Exception: any error raised while configuring, training or saving is
            reported on stderr and then re-raised unchanged.
    """
    args = parse_args()

    # The token is not part of TrainingConfig, so it used to be silently
    # dropped. Export it through the HF_TOKEN environment variable, which
    # huggingface_hub reads for authentication. It is set before any model is
    # loaded. NOTE(review): this assumes downstream Hub access goes through
    # huggingface_hub — confirm against ModelHandler/CustomTrainer.
    if args.hub_token:
        os.environ["HF_TOKEN"] = args.hub_token

    try:
        wandb_config = WandBConfig(
            enabled=args.wandb_project is not None,
            project=args.wandb_project,
            name=args.wandb_name,
            entity=args.wandb_entity,
            tags=args.wandb_tags,
            notes=args.wandb_notes,
        )
        config = TrainingConfig(
            base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
        )

        model_handler = ModelHandler(config)
        model, tokenizer = model_handler.setup_model()

        data_loader = DataLoader(tokenizer, config.data.template)
        dataset = data_loader.load_dataset(args.dataset)

        trainer_handler = CustomTrainer(config)
        trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
        # The returned training stats were never used here; logging is
        # handled inside train_and_log.
        trainer_handler.train_and_log(trainer)

        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        print("Training completed successfully!")

    except Exception as e:
        # Top-level boundary: report on stderr so it is not mixed with normal
        # output, then re-raise to keep the traceback and non-zero exit code.
        print(f"Error during training: {e}", file=sys.stderr)
        raise


if __name__ == "__main__":
    main()