159 lines
5.2 KiB
Python
159 lines
5.2 KiB
Python
import argparse
|
|
from config import TrainingConfig, WandBConfig
|
|
from data_loader import DataLoader
|
|
from model_handler import ModelHandler
|
|
from trainer import CustomTrainer
|
|
from dataclasses import asdict, fields, is_dataclass
|
|
from typing import Type
|
|
|
|
|
|
def _dataclass_of(annotation):
    """Return the dataclass type that *annotation* refers to, or None.

    Recognises a bare dataclass type as well as ``Optional[DC]``,
    ``Union[DC, None]`` and ``DC | None``. Anything else (plain types,
    containers such as ``List[DC]``, unresolved string annotations) yields None.
    """
    import types
    import typing

    if isinstance(annotation, type) and is_dataclass(annotation):
        return annotation

    # typing.Union covers Optional[...]/Union[...]; types.UnionType covers the
    # PEP 604 ``X | None`` form on Python 3.10+ (absent on older versions).
    union_origins = (typing.Union, getattr(types, "UnionType", typing.Union))
    if typing.get_origin(annotation) in union_origins:
        members = [a for a in typing.get_args(annotation) if a is not type(None)]
        if (
            len(members) == 1
            and isinstance(members[0], type)
            and is_dataclass(members[0])
        ):
            return members[0]
    return None


def reconstruct_dataclass(cls: Type, data: dict):
    """Rebuild a (possibly nested) dataclass instance from a plain dict.

    This is the inverse of ``dataclasses.asdict`` for nested configs. Take any
    field whose annotated type is a dataclass, or ``Optional`` of one, and
    whose value in *data* is a ``dict``: that value is rebuilt recursively.
    Every other value is passed through unchanged, so ``None`` or an
    already-constructed instance is kept as-is.

    Args:
        cls: The dataclass type to instantiate.
        data: Mapping of field names to values, e.g. the output of ``asdict``.

    Returns:
        A new instance of ``cls``.

    Raises:
        ValueError: If ``cls`` is not a dataclass type.
        TypeError: If ``data`` contains keys ``cls`` does not accept. This is
            raised by the ``cls(**kwargs)`` call; unknown keys are deliberately
            not dropped, so typos in overrides are not silently ignored.
    """
    import typing

    if not (isinstance(cls, type) and is_dataclass(cls)):
        raise ValueError(f"{cls} is not a dataclass")

    # Under ``from __future__ import annotations`` Field.type is a str, so resolve
    # the real types first; fall back to the raw Field.type if that fails.
    try:
        hints = typing.get_type_hints(cls)
    except (NameError, TypeError):
        hints = {}
    fieldtypes = {f.name: hints.get(f.name, f.type) for f in fields(cls)}

    kwargs = {}
    for name, value in data.items():
        nested_cls = _dataclass_of(fieldtypes.get(name))
        # Only recurse into dicts: None or an existing instance is kept as given.
        if nested_cls is not None and isinstance(value, dict):
            kwargs[name] = reconstruct_dataclass(nested_cls, value)
        else:
            kwargs[name] = value
    return cls(**kwargs)
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for a LoRA fine-tuning run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``, for tests and programmatic use. ``None`` keeps the
            default behaviour of reading ``sys.argv``.

    Returns:
        argparse.Namespace with these attributes:
        - W&B options: ``wandb_project``, ``wandb_name``, ``wandb_entity``,
          ``wandb_tags``, ``wandb_notes``, ``wandb_sweep``.
        - Run options: ``base_model``, ``dataset``, ``output_dir``,
          ``hub_token``.

    Exits (through ``parser.error``, status 2) when ``--dataset`` is missing,
    or when ``--wandb_sweep`` is given without ``--wandb_project``.
    """
    parser = argparse.ArgumentParser(description="Train a language model with LoRA")

    # wandb args
    wandb_group = parser.add_argument_group("Weights & Biases")
    wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
    wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
    wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
    wandb_group.add_argument(
        "--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
    )
    wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
    wandb_group.add_argument(
        "--wandb_sweep", action="store_true", help="Enable wandb sweep"
    )

    # rest
    parser.add_argument(
        "--base_model",
        type=str,
        default="unsloth/Qwen2.5-7B",
        help="Base model to fine-tune",
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Path to the training dataset"
    )
    parser.add_argument(
        "--output_dir", type=str, default="/output/", help="Directory to save the model"
    )
    parser.add_argument("--hub_token", type=str, help="Hugging Face Hub token")

    args = parser.parse_args(argv)

    # W&B logging is enabled only when a project is given (enabled=bool(project)
    # in main), so a sweep without a project would run with logging switched off.
    # Its metric would then never be reported.
    if args.wandb_sweep and not args.wandb_project:
        parser.error("--wandb_sweep requires --wandb_project")

    return args
|
|
|
|
|
|
def train_single_run(config: TrainingConfig, dataset_path: str):
    """Train once with ``config`` on the dataset found at ``dataset_path``.

    Steps, in order:
    - build the model and tokenizer through ``ModelHandler``;
    - load the dataset through ``DataLoader``;
    - train through ``CustomTrainer``;
    - save both model and tokenizer into ``config.output_dir``.

    Returns:
        The value returned by ``CustomTrainer.train_and_log`` for this run.

    Raises:
        Any exception raised along the way. A short message is printed first,
        then the original exception propagates unchanged.
    """
    try:
        handler = ModelHandler(config)
        model, tokenizer = handler.setup_model()

        loader = DataLoader(tokenizer, config.data)
        train_data = loader.load_dataset(dataset_path)

        runner = CustomTrainer(config)
        trainer_obj = runner.create_trainer(model, tokenizer, train_data)
        stats = runner.train_and_log(trainer_obj)

        # Save the model first, then the tokenizer, into the same directory.
        for artifact in (model, tokenizer):
            artifact.save_pretrained(config.output_dir)

        print("Training completed successfully!")
        return stats
    except Exception as exc:
        # Print a one-line summary, then re-raise so callers see the traceback.
        print(f"Error during training: {exc}")
        raise
|
|
|
|
|
|
def run_sweep(base_config: TrainingConfig, dataset_path: str):
    """Create a W&B Bayesian hyperparameter sweep and run an agent on it.

    Each trial does the following:
    - starts from ``base_config``;
    - overrides its top-level fields with the values W&B sampled for that trial;
    - trains via ``train_single_run`` into its own subdirectory of
      ``output_dir``, named after the W&B run id.

    Args:
        base_config: Configuration shared by all trials. Its ``wandb.entity``
            and ``wandb.project`` decide where the sweep is created and joined.
        dataset_path: Path to the training dataset, passed to every trial.
    """
    import os

    import wandb

    # NOTE(review): the trainer must log a metric named "val_loss" for the Bayesian
    # search and hyperband early termination to have a target — confirm against
    # CustomTrainer's logging.
    sweep_config = {
        "method": "bayes",
        "metric": {"name": "val_loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {
                "distribution": "log_uniform_values",
                "min": 1e-7,
                "max": 1e-5,
            },
            "lora_r": {"values": [32]},
            "lora_alpha": {"values": [64]},
            "per_device_train_batch_size": {"values": [32]},
            "gradient_accumulation_steps": {"values": [4, 8]},
            "num_train_epochs": {"values": [1]},
            "warmup_steps": {"values": [10]},
            "max_grad_norm": {"values": [0.1, 0.3, 0.5]},
        },
        "early_terminate": {"type": "hyperband", "min_iter": 100},
    }

    wandb_cfg = base_config.wandb
    # Pass the entity too, so --wandb_entity is honoured for the sweep.
    sweep_id = wandb.sweep(
        sweep_config, entity=wandb_cfg.entity, project=wandb_cfg.project
    )

    def sweep_train():
        with wandb.init() as run:
            # Sweep parameters are merged into the top level of the config dict.
            # Their names must therefore match top-level TrainingConfig fields:
            # reconstruct_dataclass hands unknown keys to the constructor, which
            # rejects them.
            config_dict = asdict(base_config)
            config_dict.update(wandb.config)

            # lora_alpha is sampled independently of lora_r (see sweep_config).
            # Derive it here (e.g. 2 * lora_r) if the two should be coupled.

            # Give each trial its own output directory. Otherwise every trial
            # saves into the same folder and overwrites the previous trial's model.
            config_dict["output_dir"] = os.path.join(config_dict["output_dir"], run.id)

            # Log effective batch size for monitoring
            effective_batch_size = (
                config_dict["per_device_train_batch_size"]
                * config_dict["gradient_accumulation_steps"]
            )
            wandb.log({"effective_batch_size": effective_batch_size})

            # Create training config and run
            run_config = reconstruct_dataclass(TrainingConfig, config_dict)
            print(run_config)
            train_single_run(run_config, dataset_path)

    wandb.agent(
        sweep_id,
        function=sweep_train,
        entity=wandb_cfg.entity,
        project=wandb_cfg.project,
    )
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, build the config, and train.

    With ``--wandb_sweep`` a W&B hyperparameter sweep is started via
    ``run_sweep``. Otherwise a single run is trained via ``train_single_run``.
    """
    import os

    args = parse_args()

    # --hub_token used to be parsed and then ignored. huggingface_hub reads the
    # HF_TOKEN environment variable for authentication, so exporting it here
    # makes the token available to any Hub access later in this process.
    if args.hub_token:
        os.environ["HF_TOKEN"] = args.hub_token

    # W&B logging is switched on only when a project name was given.
    wandb_config = WandBConfig(
        enabled=bool(args.wandb_project),
        project=args.wandb_project,
        name=args.wandb_name,
        entity=args.wandb_entity,
        tags=args.wandb_tags,
        notes=args.wandb_notes,
    )

    base_config = TrainingConfig(
        base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
    )

    if args.wandb_sweep:
        run_sweep(base_config, args.dataset)
    else:
        train_single_run(base_config, args.dataset)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
|