chore: up

2025-02-14 21:50:55 +06:00
parent 45f45f4bdb
commit efd62d7c94
4 changed files with 185 additions and 65 deletions
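For reference, a minimal sketch of the TrainingConfig fields this diff reads from. The field names come from the diff itself (the hard-coded save_steps=50 / save_total_limit=3 being moved into config are the old literals); the dataclass layout and every other default value are assumptions, not part of the commit:

# Hypothetical sketch of the config the new code expects; layout and defaults are assumed.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class WandbConfig:
    enabled: bool = False
    project: str = "my-project"          # hypothetical default
    name: Optional[str] = None
    entity: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    notes: Optional[str] = None


@dataclass
class TrainingConfig:
    base_model: str = "base-model-name"  # hypothetical default
    lora_r: int = 16
    lora_alpha: int = 16
    learning_rate: float = 2e-4
    per_device_train_batch_size: int = 2
    num_train_epochs: int = 1
    max_seq_length: int = 2048
    dataset_num_proc: int = 2
    # Checkpointing: previously hard-coded in create_trainer(), now configurable
    save_strategy: str = "steps"
    save_steps: int = 50
    save_total_limit: int = 3
    # Evaluation settings referenced by the new TrainingArguments kwargs
    fp16_full_eval: bool = True
    per_device_eval_batch_size: int = 2
    eval_accumulation_steps: int = 4
    eval_strategy: str = "steps"
    eval_steps: int = 50
    wandb: WandbConfig = field(default_factory=WandbConfig)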

@@ -8,6 +8,18 @@ class CustomTrainer:
     def __init__(self, config: TrainingConfig):
         self.config = config
         self._setup_gpu_tracking()
+        self.wandb = None
+        if self.config.wandb.enabled:
+            try:
+                import wandb
+
+                self.wandb = wandb
+            except ImportError:
+                print(
+                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
+                )
+            except Exception as e:
+                print(f"Warning: Failed to initialize wandb: {e}")
 
     def _setup_gpu_tracking(self):
         self.gpu_stats = torch.cuda.get_device_properties(0)
@@ -17,35 +29,24 @@ class CustomTrainer:
         self.max_memory = round(self.gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
 
     def _setup_wandb(self):
-        if self.config.wandb.enabled:
-            try:
-                import wandb
-
-                # Initialize wandb
-                wandb.init(
-                    project=self.config.wandb.project,
-                    name=self.config.wandb.name,
-                    entity=self.config.wandb.entity,
-                    tags=self.config.wandb.tags,
-                    notes=self.config.wandb.notes,
-                    config={
-                        "model": self.config.base_model,
-                        "lora_r": self.config.lora_r,
-                        "lora_alpha": self.config.lora_alpha,
-                        "learning_rate": self.config.learning_rate,
-                        "batch_size": self.config.per_device_train_batch_size,
-                        "epochs": self.config.num_train_epochs,
-                    },
-                )
-                return ["wandb"]
-            except ImportError:
-                print(
-                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
-                )
-                return None
-            except Exception as e:
-                print(f"Warning: Failed to initialize wandb: {e}")
-                return None
+        if self.config.wandb.enabled and self.wandb and self.wandb.run is None:
+            # Initialize wandb
+            self.wandb.init(
+                project=self.config.wandb.project,
+                name=self.config.wandb.name,
+                entity=self.config.wandb.entity,
+                tags=self.config.wandb.tags,
+                notes=self.config.wandb.notes,
+                config={
+                    "model": self.config.base_model,
+                    "lora_r": self.config.lora_r,
+                    "lora_alpha": self.config.lora_alpha,
+                    "learning_rate": self.config.learning_rate,
+                    "batch_size": self.config.per_device_train_batch_size,
+                    "epochs": self.config.num_train_epochs,
+                },
+            )
+            return ["wandb"]
+        return None
 
     def create_trainer(self, model, tokenizer, dataset) -> SFTTrainer:
@@ -65,15 +66,21 @@ class CustomTrainer:
             bf16=torch.cuda.is_bf16_supported(),
             optim="adamw_8bit",
             report_to=report_to,
-            save_strategy="steps",
-            save_steps=50,
-            save_total_limit=3,
+            save_strategy=self.config.save_strategy,
+            save_steps=self.config.save_steps,
+            save_total_limit=self.config.save_total_limit,
+            fp16_full_eval=self.config.fp16_full_eval,
+            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
+            eval_accumulation_steps=self.config.eval_accumulation_steps,
+            eval_strategy=self.config.eval_strategy,
+            eval_steps=self.config.eval_steps,
         )
 
         return SFTTrainer(
             model=model,
             tokenizer=tokenizer,
-            train_dataset=dataset,
+            train_dataset=dataset["train"],
+            eval_dataset=dataset["test"],
             dataset_text_field="text",
             max_seq_length=self.config.max_seq_length,
             dataset_num_proc=self.config.dataset_num_proc,
@@ -90,11 +97,6 @@ class CustomTrainer:
         return trainer_stats
 
     def _log_training_stats(self, trainer_stats):
-        try:
-            import wandb
-        except ImportError:
-            wandb = None
-
         used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
         used_memory_for_lora = round(used_memory - self.start_gpu_memory, 3)
         used_percentage = round(used_memory / self.max_memory * 100, 3)
@@ -111,8 +113,8 @@ class CustomTrainer:
             f"Peak reserved memory for training % of max memory = {lora_percentage} %."
         )
 
-        if wandb and self.config.wandb.enabled:
-            wandb.log(
+        if self.wandb and self.config.wandb.enabled:
+            self.wandb.log(
                 {
                     "training_time_seconds": trainer_stats.metrics["train_runtime"],
                     "training_time_minutes": round(