chore: trainer
This commit is contained in:
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"python.languageServer": "None"
|
||||||
|
}
|
||||||
63
README.md
63
README.md
@@ -0,0 +1,63 @@
|
|||||||
|
# Unsloth LoRA scripts
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
1. Clone the repository:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://git.hye.su/mira/unsloth-train-scripts.git
|
||||||
|
cd unsloth-train-scripts
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install pytorch and unsloth:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/_auto_install.py | python -
|
||||||
|
pip install gdown # Optional: Only needed for Google Drive datasets
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
unsloth-lora-training/
|
||||||
|
├── config.py # Configuration settings
|
||||||
|
├── data_loader.py # Dataset loading and processing
|
||||||
|
├── model_handler.py # Model initialization and PEFT setup
|
||||||
|
├── trainer.py # Training loop and metrics
|
||||||
|
├── main.py # Main training script
|
||||||
|
└── README.md # This file
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
All configuration settings are managed in `config.py`. The main configuration class is `TrainingConfig`
|
||||||
|
|
||||||
|
To modify the default configuration, edit the `TrainingConfig` class in `config.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class TrainingConfig:
|
||||||
|
base_model: str = "unsloth/Qwen2.5-7B"
|
||||||
|
max_seq_length: int = 16384
|
||||||
|
# ... modify other parameters as needed
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py \
|
||||||
|
--base_model mistralai/Mistral-7B-v0.1 \
|
||||||
|
--dataset path/to/your/dataset.json \
|
||||||
|
--output_dir ./custom_output
|
||||||
|
--hub_token "secret"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Google Drive Dataset
|
||||||
|
|
||||||
|
Train using a dataset stored on Google Drive:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py \
|
||||||
|
--dataset https://drive.google.com/file/d/your_file_id/view \
|
||||||
|
--output_dir ./drive_output
|
||||||
|
```
|
||||||
|
|||||||
73
config.py
Normal file
73
config.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataConfig:
|
||||||
|
# Default configuration
|
||||||
|
template: str = """Translate this Chinese text to English:
|
||||||
|
{}
|
||||||
|
===
|
||||||
|
Translation:
|
||||||
|
{}"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WandBConfig:
|
||||||
|
enabled: bool = True
|
||||||
|
project: str = "lora-training"
|
||||||
|
name: str | None = None
|
||||||
|
entity: str | None = None
|
||||||
|
tags: list[str] = []
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingConfig:
|
||||||
|
wandb: WandBConfig = WandBConfig()
|
||||||
|
data: DataConfig = DataConfig()
|
||||||
|
|
||||||
|
# model
|
||||||
|
base_model: str = "unsloth/Qwen2.5-7B"
|
||||||
|
max_seq_length: int = 6144
|
||||||
|
dtype: str | None = None
|
||||||
|
load_in_4bit: bool = True
|
||||||
|
|
||||||
|
# LoRA
|
||||||
|
lora_r: int = 16
|
||||||
|
lora_alpha: int = 16
|
||||||
|
lora_dropout: float = 0
|
||||||
|
target_modules: list[str] = []
|
||||||
|
use_gradient_checkpointing: str = "unsloth"
|
||||||
|
random_state: int = 3407
|
||||||
|
use_rslora: bool = False
|
||||||
|
loftq_config: dict | None = None
|
||||||
|
|
||||||
|
# training args
|
||||||
|
per_device_train_batch_size: int = 32
|
||||||
|
gradient_accumulation_steps: int = 1
|
||||||
|
warmup_ratio: float = 0.05
|
||||||
|
max_grad_norm: float = 1.0
|
||||||
|
num_train_epochs: float = 0.5
|
||||||
|
learning_rate: float = 3e-5
|
||||||
|
weight_decay: float = 0.05
|
||||||
|
lr_scheduler_type: str = "linear"
|
||||||
|
logging_steps: int = 5
|
||||||
|
|
||||||
|
# dataset
|
||||||
|
dataset_num_proc: int = 2
|
||||||
|
packing: bool = False
|
||||||
|
|
||||||
|
# output
|
||||||
|
output_dir: str = "/output/"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.target_modules:
|
||||||
|
self.target_modules = [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
"o_proj",
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
"down_proj",
|
||||||
|
]
|
||||||
50
data_loader.py
Normal file
50
data_loader.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from datasets import Dataset, load_dataset
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoader:
|
||||||
|
def __init__(self, tokenizer: PreTrainedTokenizer, template: str):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self._template = template
|
||||||
|
|
||||||
|
def load_dataset(self, path: str) -> Dataset:
|
||||||
|
"""Load dataset from local path or Google Drive"""
|
||||||
|
if "drive.google.com" in str(path):
|
||||||
|
try:
|
||||||
|
import gdown
|
||||||
|
|
||||||
|
local_path = "downloaded_dataset.json"
|
||||||
|
if not os.path.exists(local_path):
|
||||||
|
gdown.download(url=path, output=local_path, fuzzy=True)
|
||||||
|
dataset_path = local_path
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install gdown: pip install gdown")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error downloading from Google Drive: {e}")
|
||||||
|
else:
|
||||||
|
dataset_path = path
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataset = load_dataset("json", data_files=dataset_path, split="train")
|
||||||
|
return self.process_dataset(dataset)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error loading dataset: {e}")
|
||||||
|
|
||||||
|
def process_dataset(self, dataset: Dataset) -> Dataset:
|
||||||
|
"""Process and format the dataset"""
|
||||||
|
|
||||||
|
def formatting_func(examples: dict[str, Any]) -> dict[str, list[str]]:
|
||||||
|
inputs: list[str] = examples["input"]
|
||||||
|
outputs: list[str] = examples["output"]
|
||||||
|
texts: list[str] = []
|
||||||
|
for input, output in zip(inputs, outputs):
|
||||||
|
text = (
|
||||||
|
self._template.format(input=input, output=output)
|
||||||
|
+ self.tokenizer.eos_token
|
||||||
|
)
|
||||||
|
texts.append(text)
|
||||||
|
return {"text": texts}
|
||||||
|
|
||||||
|
return dataset.map(formatting_func, batched=True)
|
||||||
77
main.py
Normal file
77
main.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import argparse
|
||||||
|
from config import TrainingConfig, WandBConfig
|
||||||
|
from data_loader import DataLoader
|
||||||
|
from model_handler import ModelHandler
|
||||||
|
from trainer import CustomTrainer
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Train a language model with LoRA")
|
||||||
|
|
||||||
|
# wandb args
|
||||||
|
wandb_group = parser.add_argument_group("Weights & Biases")
|
||||||
|
wandb_group.add_argument(
|
||||||
|
"--wandb_project", type=str, default="lora-training", help="WandB project name"
|
||||||
|
)
|
||||||
|
wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
|
||||||
|
wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
|
||||||
|
wandb_group.add_argument(
|
||||||
|
"--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
|
||||||
|
)
|
||||||
|
wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
|
||||||
|
# rest
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_model",
|
||||||
|
type=str,
|
||||||
|
default="unsloth/Qwen2.5-7B",
|
||||||
|
help="Base model to fine-tune",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset", type=str, required=True, help="Path to the training dataset"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir", type=str, default="/output/", help="Directory to save the model"
|
||||||
|
)
|
||||||
|
parser.add_argument("--hub_token", type=str, help="Hugging Face Hub token")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
wandb_config = WandBConfig(
|
||||||
|
enabled=args.wandb_project is not None,
|
||||||
|
project=args.wandb_project or "lora-training",
|
||||||
|
name=args.wandb_name,
|
||||||
|
entity=args.wandb_entity,
|
||||||
|
tags=args.wandb_tags,
|
||||||
|
notes=args.wandb_notes,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = TrainingConfig(
|
||||||
|
base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
|
||||||
|
)
|
||||||
|
|
||||||
|
model_handler = ModelHandler(config)
|
||||||
|
model, tokenizer = model_handler.setup_model()
|
||||||
|
|
||||||
|
data_loader = DataLoader(tokenizer, config.data.template)
|
||||||
|
dataset = data_loader.load_dataset(args.dataset)
|
||||||
|
|
||||||
|
trainer_handler = CustomTrainer(config)
|
||||||
|
trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
|
||||||
|
trainer_stats = trainer_handler.train_and_log(trainer)
|
||||||
|
|
||||||
|
model.save_pretrained(args.output_dir)
|
||||||
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
print("Training completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during training: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
42
model_handler.py
Normal file
42
model_handler.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import torch
|
||||||
|
from unsloth import FastLanguageModel
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
from config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ModelHandler:
|
||||||
|
def __init__(self, config: TrainingConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def setup_model(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||||
|
try:
|
||||||
|
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||||
|
model_name=self.config.base_model,
|
||||||
|
max_seq_length=self.config.max_seq_length,
|
||||||
|
dtype=self.config.dtype,
|
||||||
|
load_in_4bit=self.config.load_in_4bit,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = self._setup_peft(model)
|
||||||
|
return model, tokenizer
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error setting up model: {e}")
|
||||||
|
|
||||||
|
def _setup_peft(self, model: PreTrainedModel) -> PreTrainedModel:
|
||||||
|
"""Setup PEFT config for the model"""
|
||||||
|
try:
|
||||||
|
return FastLanguageModel.get_peft_model(
|
||||||
|
model,
|
||||||
|
r=self.config.lora_r,
|
||||||
|
target_modules=self.config.target_modules,
|
||||||
|
lora_alpha=self.config.lora_alpha,
|
||||||
|
lora_dropout=self.config.lora_dropout,
|
||||||
|
bias="none",
|
||||||
|
use_gradient_checkpointing=self.config.use_gradient_checkpointing,
|
||||||
|
random_state=self.config.random_state,
|
||||||
|
max_seq_length=self.config.max_seq_length,
|
||||||
|
use_rslora=self.config.use_rslora,
|
||||||
|
loftq_config=self.config.loftq_config,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error setting up PEFT: {e}")
|
||||||
123
trainer.py
Normal file
123
trainer.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from transformers import TrainingArguments
|
||||||
|
from trl import SFTTrainer
|
||||||
|
import torch
|
||||||
|
from config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTrainer:
|
||||||
|
def __init__(self, config: TrainingConfig):
|
||||||
|
self.config = config
|
||||||
|
self._setup_gpu_tracking()
|
||||||
|
|
||||||
|
def _setup_gpu_tracking(self):
|
||||||
|
self.gpu_stats = torch.cuda.get_device_properties(0)
|
||||||
|
self.start_gpu_memory = round(
|
||||||
|
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3
|
||||||
|
)
|
||||||
|
self.max_memory = round(self.gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
|
||||||
|
|
||||||
|
def _setup_wandb(self):
|
||||||
|
if self.config.wandb.enabled:
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
# Initialize wandb
|
||||||
|
wandb.init(
|
||||||
|
project=self.config.wandb.project,
|
||||||
|
name=self.config.wandb.name,
|
||||||
|
entity=self.config.wandb.entity,
|
||||||
|
tags=self.config.wandb.tags,
|
||||||
|
notes=self.config.wandb.notes,
|
||||||
|
config={
|
||||||
|
"model": self.config.base_model,
|
||||||
|
"lora_r": self.config.lora_r,
|
||||||
|
"lora_alpha": self.config.lora_alpha,
|
||||||
|
"learning_rate": self.config.learning_rate,
|
||||||
|
"batch_size": self.config.per_device_train_batch_size,
|
||||||
|
"epochs": self.config.num_train_epochs,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return ["wandb"]
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"Warning: wandb not installed. Run `pip install wandb` to enable logging."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to initialize wandb: {e}")
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_trainer(self, model, tokenizer, dataset) -> SFTTrainer:
|
||||||
|
report_to = self._setup_wandb()
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=self.config.output_dir,
|
||||||
|
per_device_train_batch_size=self.config.per_device_train_batch_size,
|
||||||
|
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
|
||||||
|
warmup_ratio=self.config.warmup_ratio,
|
||||||
|
max_grad_norm=self.config.max_grad_norm,
|
||||||
|
num_train_epochs=self.config.num_train_epochs,
|
||||||
|
learning_rate=self.config.learning_rate,
|
||||||
|
weight_decay=self.config.weight_decay,
|
||||||
|
lr_scheduler_type=self.config.lr_scheduler_type,
|
||||||
|
logging_steps=self.config.logging_steps,
|
||||||
|
fp16=not torch.cuda.is_bf16_supported(),
|
||||||
|
bf16=torch.cuda.is_bf16_supported(),
|
||||||
|
optim="adamw_8bit",
|
||||||
|
report_to=report_to,
|
||||||
|
)
|
||||||
|
|
||||||
|
return SFTTrainer(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_dataset=dataset,
|
||||||
|
dataset_text_field="text",
|
||||||
|
max_seq_length=self.config.max_seq_length,
|
||||||
|
dataset_num_proc=self.config.dataset_num_proc,
|
||||||
|
packing=self.config.packing,
|
||||||
|
args=training_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_and_log(self, trainer: SFTTrainer) -> dict:
|
||||||
|
print(f"GPU = {self.gpu_stats.name}. Max memory = {self.max_memory} GB.")
|
||||||
|
print(f"{self.start_gpu_memory} GB of memory reserved.")
|
||||||
|
|
||||||
|
trainer_stats = trainer.train()
|
||||||
|
self._log_training_stats(trainer_stats)
|
||||||
|
return trainer_stats
|
||||||
|
|
||||||
|
def _log_training_stats(self, trainer_stats):
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError:
|
||||||
|
wandb = None
|
||||||
|
|
||||||
|
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
|
||||||
|
used_memory_for_lora = round(used_memory - self.start_gpu_memory, 3)
|
||||||
|
used_percentage = round(used_memory / self.max_memory * 100, 3)
|
||||||
|
lora_percentage = round(used_memory_for_lora / self.max_memory * 100, 3)
|
||||||
|
|
||||||
|
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
|
||||||
|
print(
|
||||||
|
f"{round(trainer_stats.metrics['train_runtime'] / 60, 2)} minutes used for training."
|
||||||
|
)
|
||||||
|
print(f"Peak reserved memory = {used_memory} GB.")
|
||||||
|
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
|
||||||
|
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
|
||||||
|
print(
|
||||||
|
f"Peak reserved memory for training % of max memory = {lora_percentage} %."
|
||||||
|
)
|
||||||
|
|
||||||
|
if wandb and self.config.wandb.enabled:
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"training_time_seconds": trainer_stats.metrics["train_runtime"],
|
||||||
|
"training_time_minutes": round(
|
||||||
|
trainer_stats.metrics["train_runtime"] / 60, 2
|
||||||
|
),
|
||||||
|
"peak_memory_gb": used_memory,
|
||||||
|
"training_memory_gb": used_memory_for_lora,
|
||||||
|
"peak_memory_percentage": used_percentage,
|
||||||
|
"training_memory_percentage": lora_percentage,
|
||||||
|
}
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user