# unsloth-train-scripts/main.py

import argparse
from dataclasses import asdict, fields, is_dataclass
from typing import Type

from config import TrainingConfig, WandBConfig
from data_loader import DataLoader
from model_handler import ModelHandler
from trainer import CustomTrainer


def reconstruct_dataclass(cls: Type, data: dict):
    """Recursively rebuild a (possibly nested) dataclass from a plain dict."""
    if not is_dataclass(cls):
        raise ValueError(f"{cls} is not a dataclass")
    fieldtypes = {f.name: f.type for f in fields(cls)}
    kwargs = {}
    for name, value in data.items():
        # Recurse into fields that are themselves dataclasses (nested configs).
        if name in fieldtypes and is_dataclass(fieldtypes[name]):
            kwargs[name] = reconstruct_dataclass(fieldtypes[name], value)
        else:
            kwargs[name] = value
    return cls(**kwargs)
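
# Illustrative sketch (hypothetical dataclasses, not from this repo): given
#     @dataclass
#     class Inner:
#         x: int
#
#     @dataclass
#     class Outer:
#         inner: Inner
#         y: str
#
# reconstruct_dataclass(Outer, {"inner": {"x": 1}, "y": "a"}) yields
# Outer(inner=Inner(x=1), y="a"), whereas Outer(**data) would leave
# `inner` as a plain dict.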


def parse_args():
    parser = argparse.ArgumentParser(description="Train a language model with LoRA")

    # wandb args
    wandb_group = parser.add_argument_group("Weights & Biases")
    wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
    wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
    wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
    wandb_group.add_argument(
        "--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
    )
    wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
    wandb_group.add_argument(
        "--wandb_sweep", action="store_true", help="Enable wandb sweep"
    )

    # core training args
    parser.add_argument(
        "--base_model",
        type=str,
        default="unsloth/Qwen2.5-7B",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Path to the training dataset"
    )
    parser.add_argument(
        "--output_dir", type=str, default="/output/", help="Directory to save the model"
    )
    # NOTE: parsed but not currently consumed in main() below.
    parser.add_argument("--hub_token", type=str, help="Hugging Face Hub token")
    return parser.parse_args()
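
# Example invocations (dataset path, output dir, and project name are
# placeholders, not taken from this repo):
#   python main.py --dataset /data/train.jsonl --output_dir /output/ \
#       --wandb_project my-project --wandb_name qwen-lora-run1
#   python main.py --dataset /data/train.jsonl --wandb_project my-project --wandb_sweep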


def train_single_run(config: TrainingConfig, dataset_path: str):
    """Run a single training session with the given config."""
    try:
        model_handler = ModelHandler(config)
        model, tokenizer = model_handler.setup_model()

        data_loader = DataLoader(tokenizer, config.data)
        dataset = data_loader.load_dataset(dataset_path)

        trainer_handler = CustomTrainer(config)
        trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
        trainer_stats = trainer_handler.train_and_log(trainer)

        model.save_pretrained(config.output_dir)
        tokenizer.save_pretrained(config.output_dir)

        print("Training completed successfully!")
        return trainer_stats
    except Exception as e:
        print(f"Error during training: {e}")
        raise


def run_sweep(base_config: TrainingConfig, dataset_path: str):
    """Launch a Bayesian-optimization wandb sweep over selected hyperparameters."""
    import wandb

    sweep_config = {
        "method": "bayes",
        "metric": {"name": "val_loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {
                "distribution": "log_uniform_values",
                "min": 1e-7,
                "max": 1e-5,
            },
            "lora_r": {"values": [32]},
            "lora_alpha": {"values": [64]},
            "per_device_train_batch_size": {"values": [32]},
            "gradient_accumulation_steps": {"values": [4, 8]},
            "num_train_epochs": {"values": [1]},
            "warmup_steps": {"values": [10]},
            "max_grad_norm": {"values": [0.1, 0.3, 0.5]},
        },
        "early_terminate": {"type": "hyperband", "min_iter": 100},
    }
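    # With the fixed per-device batch size of 32, the swept accumulation steps
    # give an effective batch size of 32 * 4 = 128 or 32 * 8 = 256 per device.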
    sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)

    def sweep_train():
        with wandb.init():
            # Convert the base config to a plain dict.
            config_dict = (
                asdict(base_config)
                if hasattr(base_config, "__dataclass_fields__")
                else base_config.__dict__.copy()
            )
            # Overlay the sweep parameters chosen for this trial.
            config_dict.update(wandb.config)
            # Set lora_alpha based on lora_r
            # config_dict["lora_alpha"] = 2 * config_dict["lora_r"]

            # Log effective batch size for monitoring.
            effective_batch_size = (
                config_dict["per_device_train_batch_size"]
                * config_dict["gradient_accumulation_steps"]
            )
            wandb.log({"effective_batch_size": effective_batch_size})

            # Create the training config and run.
            run_config = reconstruct_dataclass(TrainingConfig, config_dict)
            print(run_config)
            train_single_run(run_config, dataset_path)

    wandb.agent(sweep_id, function=sweep_train)
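    # The call above runs trials in-process; additional machines can typically
    # join the same sweep via the wandb CLI: `wandb agent <entity>/<project>/<sweep_id>`.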


def main():
    args = parse_args()

    wandb_config = WandBConfig(
        enabled=bool(args.wandb_project),
        project=args.wandb_project,
        name=args.wandb_name,
        entity=args.wandb_entity,
        tags=args.wandb_tags,
        notes=args.wandb_notes,
    )
    base_config = TrainingConfig(
        base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
    )

    if args.wandb_sweep:
        run_sweep(base_config, args.dataset)
    else:
        train_single_run(base_config, args.dataset)


if __name__ == "__main__":
    main()