chore: trainer
This commit is contained in:
77
main.py
Normal file
77
main.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import argparse
|
||||
from config import TrainingConfig, WandBConfig
|
||||
from data_loader import DataLoader
|
||||
from model_handler import ModelHandler
|
||||
from trainer import CustomTrainer
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Train a language model with LoRA")
|
||||
|
||||
# wandb args
|
||||
wandb_group = parser.add_argument_group("Weights & Biases")
|
||||
wandb_group.add_argument(
|
||||
"--wandb_project", type=str, default="lora-training", help="WandB project name"
|
||||
)
|
||||
wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
|
||||
wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
|
||||
wandb_group.add_argument(
|
||||
"--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
|
||||
)
|
||||
wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
|
||||
# rest
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
default="unsloth/Qwen2.5-7B",
|
||||
help="Base model to fine-tune",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset", type=str, required=True, help="Path to the training dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="/output/", help="Directory to save the model"
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="Hugging Face Hub token")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
try:
|
||||
wandb_config = WandBConfig(
|
||||
enabled=args.wandb_project is not None,
|
||||
project=args.wandb_project or "lora-training",
|
||||
name=args.wandb_name,
|
||||
entity=args.wandb_entity,
|
||||
tags=args.wandb_tags,
|
||||
notes=args.wandb_notes,
|
||||
)
|
||||
|
||||
config = TrainingConfig(
|
||||
base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
|
||||
)
|
||||
|
||||
model_handler = ModelHandler(config)
|
||||
model, tokenizer = model_handler.setup_model()
|
||||
|
||||
data_loader = DataLoader(tokenizer, config.data.template)
|
||||
dataset = data_loader.load_dataset(args.dataset)
|
||||
|
||||
trainer_handler = CustomTrainer(config)
|
||||
trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
|
||||
trainer_stats = trainer_handler.train_and_log(trainer)
|
||||
|
||||
model.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
print("Training completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during training: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user