from transformers import TrainingArguments
from trl import SFTTrainer
import torch

from config import TrainingConfig


class CustomTrainer:
    def __init__(self, config: TrainingConfig):
        self.config = config
        self._setup_gpu_tracking()
        self.wandb = None
        if self.config.wandb.enabled:
            try:
                import wandb

                self.wandb = wandb
            except ImportError:
                print(
                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
                )
            except Exception as e:
                print(f"Warning: Failed to initialize wandb: {e}")

    def _setup_gpu_tracking(self):
        # Assumes a CUDA device is available; all memory figures are in GB.
        self.gpu_stats = torch.cuda.get_device_properties(0)
        self.start_gpu_memory = round(
            torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3
        )
        self.max_memory = round(self.gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

    def _setup_wandb(self):
        if self.config.wandb.enabled and self.wandb and self.wandb.run is None:
            # Initialize wandb and record the key hyperparameters with the run.
            self.wandb.init(
                project=self.config.wandb.project,
                name=self.config.wandb.name,
                entity=self.config.wandb.entity,
                tags=self.config.wandb.tags,
                notes=self.config.wandb.notes,
                config={
                    "model": self.config.base_model,
                    "lora_r": self.config.lora_r,
                    "lora_alpha": self.config.lora_alpha,
                    "learning_rate": self.config.learning_rate,
                    "batch_size": self.config.per_device_train_batch_size,
                    "epochs": self.config.num_train_epochs,
                },
            )
            return ["wandb"]
        # "none" disables all reporting integrations; returning None would make
        # TrainingArguments fall back to every installed integration.
        return "none"

    def create_trainer(self, model, tokenizer, dataset) -> SFTTrainer:
        report_to = self._setup_wandb()

        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            # warmup_ratio=self.config.warmup_ratio,
            warmup_steps=self.config.warmup_steps,
            max_grad_norm=self.config.max_grad_norm,
            num_train_epochs=self.config.num_train_epochs,
            learning_rate=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            lr_scheduler_type=self.config.lr_scheduler_type,
            logging_steps=self.config.logging_steps,
            # Prefer bf16 on GPUs that support it; otherwise fall back to fp16.
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            optim="adamw_8bit",
            report_to=report_to,
            save_strategy=self.config.save_strategy,
            save_steps=self.config.save_steps,
            save_total_limit=self.config.save_total_limit,
            fp16_full_eval=self.config.fp16_full_eval,
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            eval_accumulation_steps=self.config.eval_accumulation_steps,
            eval_strategy=self.config.eval_strategy,
            eval_steps=self.config.eval_steps,
        )

        return SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
            dataset_text_field="text",
            max_seq_length=self.config.max_seq_length,
            dataset_num_proc=self.config.dataset_num_proc,
            packing=self.config.packing,
            args=training_args,
        )

    def train_and_log(self, trainer: SFTTrainer):
        print(f"GPU = {self.gpu_stats.name}. Max memory = {self.max_memory} GB.")
        print(f"{self.start_gpu_memory} GB of memory reserved.")

        trainer_stats = trainer.train()
        self._log_training_stats(trainer_stats)
        return trainer_stats

    def _log_training_stats(self, trainer_stats):
        used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        used_memory_for_lora = round(used_memory - self.start_gpu_memory, 3)
        used_percentage = round(used_memory / self.max_memory * 100, 3)
        lora_percentage = round(used_memory_for_lora / self.max_memory * 100, 3)

        print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
        print(
            f"{round(trainer_stats.metrics['train_runtime'] / 60, 2)} minutes used for training."
        )
        print(f"Peak reserved memory = {used_memory} GB.")
        print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
        print(f"Peak reserved memory % of max memory = {used_percentage} %.")
        print(
            f"Peak reserved memory for training % of max memory = {lora_percentage} %."
        )

        if self.wandb and self.config.wandb.enabled:
            self.wandb.log(
                {
                    "training_time_seconds": trainer_stats.metrics["train_runtime"],
                    "training_time_minutes": round(
                        trainer_stats.metrics["train_runtime"] / 60, 2
                    ),
                    "peak_memory_gb": used_memory,
                    "training_memory_gb": used_memory_for_lora,
                    "peak_memory_percentage": used_percentage,
                    "training_memory_percentage": lora_percentage,
                }
            )