130 lines
5.4 KiB
Python
130 lines
5.4 KiB
Python
from transformers import TrainingArguments
|
|
from trl import SFTTrainer
|
|
import torch
|
|
from config import TrainingConfig
|
|
|
|
|
|
class CustomTrainer:
    """Build and run a TRL ``SFTTrainer`` with optional Weights & Biases logging.

    GPU memory is snapshotted at construction time so that, after training,
    the peak memory attributable to training itself can be reported.
    """

    def __init__(self, config: "TrainingConfig"):
        """Store *config*, snapshot GPU memory, and import wandb if enabled.

        Args:
            config: Project training configuration. This class reads the
                ``wandb.*`` sub-config plus model, optimizer, scheduler,
                dataset, evaluation and checkpointing settings from it.
        """
        self.config = config
        self._setup_gpu_tracking()

        # Handle to the wandb module, or None when wandb is disabled in the
        # config or could not be imported.
        self.wandb = None
        if self.config.wandb.enabled:
            try:
                import wandb

                self.wandb = wandb
            except ImportError:
                print(
                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
                )
            except Exception as e:
                # Importing wandb can also fail for reasons other than it being
                # absent (e.g. broken dependencies); logging is best-effort.
                print(f"Warning: Failed to initialize wandb: {e}")

    @staticmethod
    def _bytes_to_gb(num_bytes: int) -> float:
        """Convert a byte count to GiB (1024**3 bytes), rounded to 3 decimals."""
        return round(num_bytes / 1024 / 1024 / 1024, 3)

    def _setup_gpu_tracking(self):
        """Record GPU properties and the memory already reserved before training.

        The current CUDA device is used for both the device properties and the
        memory counters so they always describe the same GPU.
        """
        self._device = torch.cuda.current_device()
        self.gpu_stats = torch.cuda.get_device_properties(self._device)
        self.start_gpu_memory = self._bytes_to_gb(
            torch.cuda.max_memory_reserved(self._device)
        )
        self.max_memory = self._bytes_to_gb(self.gpu_stats.total_memory)

    def _setup_wandb(self):
        """Start a wandb run if needed and return the ``report_to`` value.

        Returns:
            ``["wandb"]`` when wandb is enabled and importable (a new run is
            initialized unless one is already active), otherwise ``["none"]``.
            An explicit ``"none"`` is returned rather than ``None`` because
            ``TrainingArguments(report_to=None)`` means "report to every
            installed integration" in transformers 4.x, which would silently
            re-enable wandb (and others) even when the config disables it.
        """
        if not (self.config.wandb.enabled and self.wandb is not None):
            return ["none"]

        if self.wandb.run is None:
            wandb_cfg = self.config.wandb
            self.wandb.init(
                project=wandb_cfg.project,
                name=wandb_cfg.name,
                entity=wandb_cfg.entity,
                tags=wandb_cfg.tags,
                notes=wandb_cfg.notes,
                config={
                    "model": self.config.base_model,
                    "lora_r": self.config.lora_r,
                    "lora_alpha": self.config.lora_alpha,
                    "learning_rate": self.config.learning_rate,
                    "batch_size": self.config.per_device_train_batch_size,
                    "epochs": self.config.num_train_epochs,
                },
            )
        return ["wandb"]

    def create_trainer(self, model, tokenizer, dataset) -> "SFTTrainer":
        """Create an ``SFTTrainer`` configured from ``self.config``.

        Args:
            model: Model to fine-tune.
            tokenizer: Tokenizer handed to ``SFTTrainer``.
            dataset: Mapping of splits with a ``"train"`` split and, when
                evaluation is configured, a ``"test"`` split. Samples are read
                from their ``"text"`` field.

        Returns:
            The configured ``SFTTrainer`` (training is not started).
        """
        report_to = self._setup_wandb()
        # Prefer bf16 when the GPU supports it, otherwise fall back to fp16.
        use_bf16 = torch.cuda.is_bf16_supported()

        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            warmup_steps=self.config.warmup_steps,
            max_grad_norm=self.config.max_grad_norm,
            num_train_epochs=self.config.num_train_epochs,
            learning_rate=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            lr_scheduler_type=self.config.lr_scheduler_type,
            logging_steps=self.config.logging_steps,
            fp16=not use_bf16,
            bf16=use_bf16,
            optim="adamw_8bit",
            report_to=report_to,
            save_strategy=self.config.save_strategy,
            save_steps=self.config.save_steps,
            save_total_limit=self.config.save_total_limit,
            fp16_full_eval=self.config.fp16_full_eval,
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            eval_accumulation_steps=self.config.eval_accumulation_steps,
            eval_strategy=self.config.eval_strategy,
            eval_steps=self.config.eval_steps,
        )

        # NOTE(review): tokenizer / dataset_text_field / max_seq_length /
        # dataset_num_proc / packing are SFTTrainer kwargs in older trl
        # releases; newer releases expect them on SFTConfig — confirm against
        # the pinned trl version.
        return SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset["train"],
            # A missing "test" split is only an error if evaluation is enabled;
            # the Trainer reports that case itself.
            eval_dataset=dataset.get("test"),
            dataset_text_field="text",
            max_seq_length=self.config.max_seq_length,
            dataset_num_proc=self.config.dataset_num_proc,
            packing=self.config.packing,
            args=training_args,
        )

    def train_and_log(self, trainer: "SFTTrainer"):
        """Run training and report runtime and GPU memory usage.

        Args:
            trainer: Trainer returned by :meth:`create_trainer`.

        Returns:
            The ``TrainOutput`` returned by ``trainer.train()``; its
            ``metrics["train_runtime"]`` is what gets reported.
        """
        print(f"GPU = {self.gpu_stats.name}. Max memory = {self.max_memory} GB.")
        print(f"{self.start_gpu_memory} GB of memory reserved.")

        trainer_stats = trainer.train()
        self._log_training_stats(trainer_stats)
        return trainer_stats

    def _log_training_stats(self, trainer_stats):
        """Print runtime and peak-memory figures and mirror them to wandb.

        "Training" memory is the peak reserved memory minus what was already
        reserved when this object was constructed.
        """
        used_memory = self._bytes_to_gb(
            torch.cuda.max_memory_reserved(getattr(self, "_device", None))
        )
        used_memory_for_lora = round(used_memory - self.start_gpu_memory, 3)
        used_percentage = round(used_memory / self.max_memory * 100, 3)
        lora_percentage = round(used_memory_for_lora / self.max_memory * 100, 3)

        runtime_seconds = trainer_stats.metrics["train_runtime"]
        runtime_minutes = round(runtime_seconds / 60, 2)

        print(f"{runtime_seconds} seconds used for training.")
        print(f"{runtime_minutes} minutes used for training.")
        print(f"Peak reserved memory = {used_memory} GB.")
        print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
        print(f"Peak reserved memory % of max memory = {used_percentage} %.")
        print(
            f"Peak reserved memory for training % of max memory = {lora_percentage} %."
        )

        # wandb.log raises when no run is active, so only log to a live run.
        if (
            self.wandb is not None
            and self.config.wandb.enabled
            and self.wandb.run is not None
        ):
            self.wandb.log(
                {
                    "training_time_seconds": runtime_seconds,
                    "training_time_minutes": runtime_minutes,
                    "peak_memory_gb": used_memory,
                    "training_memory_gb": used_memory_for_lora,
                    "peak_memory_percentage": used_percentage,
                    "training_memory_percentage": lora_percentage,
                }
            )