From efd62d7c9409c9c7e4a23a45644b6aa742314234 Mon Sep 17 00:00:00 2001
From: kuwoyuki
Date: Fri, 14 Feb 2025 21:50:55 +0600
Subject: [PATCH] chore: up

---
 config.py      |  16 +++++-
 data_loader.py |  20 ++++++--
 main.py        | 132 +++++++++++++++++++++++++++++++++++++++++--------
 trainer.py     |  82 +++++++++++++++---------------
 4 files changed, 185 insertions(+), 65 deletions(-)

diff --git a/config.py b/config.py
index 043e7a3..6295d2d 100644
--- a/config.py
+++ b/config.py
@@ -5,9 +5,11 @@ from dataclasses import dataclass, field
 class DataConfig:
     template: str = """Translate this Chinese text to English:
 {input}
-===
+---
 Translation:
 {output}"""
+    train_split: float = 0.9
+    max_samples: int | None = None
 
 
 @dataclass
@@ -62,10 +64,22 @@ class TrainingConfig:
     lr_scheduler_type: str = "cosine"
     logging_steps: int = 1
 
+    # save
+    save_strategy: str = "steps"
+    save_steps: float = 100
+    save_total_limit: int | None = 3
+
     # dataset
     dataset_num_proc: int = 8
     packing: bool = True
 
+    # eval (annotated so they are real dataclass fields, not class attributes)
+    fp16_full_eval: bool = True
+    per_device_eval_batch_size: int = 2
+    eval_accumulation_steps: int = 4
+    eval_strategy: str = "steps"
+    eval_steps: int = 1
+
     # output
     output_dir: str = "/workspace/output/"
 
diff --git a/data_loader.py b/data_loader.py
index 72815c1..ab788fb 100644
--- a/data_loader.py
+++ b/data_loader.py
@@ -1,13 +1,15 @@
 import os
 from typing import Any
+from config import DataConfig
 from datasets import Dataset, load_dataset
 from transformers import PreTrainedTokenizer
 
 
 class DataLoader:
-    def __init__(self, tokenizer: PreTrainedTokenizer, template: str):
+    def __init__(self, tokenizer: PreTrainedTokenizer, data_config: DataConfig):
         self.tokenizer = tokenizer
-        self._template = template
+        self.data_config = data_config
+        # self._template = template
 
     def load_dataset(self, path: str) -> Dataset:
         """Load dataset from local path or Google Drive"""
@@ -28,7 +30,17 @@ class DataLoader:
 
         try:
             dataset = load_dataset("json", data_files=dataset_path, split="train")
-            return self.process_dataset(dataset)
+
+            if (max_size := self.data_config.max_samples) is not None:
+                dataset = dataset.select(range(min(len(dataset), max_size)))
+
+            processed_dataset = self.process_dataset(dataset)
+            # train/test split
+            split_dataset = processed_dataset.train_test_split(
+                test_size=(1 - self.data_config.train_split), shuffle=False
+            )
+
+            return split_dataset
         except Exception as e:
             raise Exception(f"Error loading dataset: {e}")
 
@@ -41,7 +53,7 @@
         texts: list[str] = []
         for input, output in zip(inputs, outputs):
             text = (
-                self._template.format(input=input, output=output)
+                self.data_config.template.format(input=input, output=output)
                 + self.tokenizer.eos_token
             )
             texts.append(text)
diff --git a/main.py b/main.py
index 555461f..07d76aa 100644
--- a/main.py
+++ b/main.py
@@ -3,6 +3,7 @@
 from config import TrainingConfig, WandBConfig
 from data_loader import DataLoader
 from model_handler import ModelHandler
 from trainer import CustomTrainer
+from dataclasses import asdict
 
 def parse_args():
@@ -17,6 +18,9 @@
         "--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
     )
     wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
+    wandb_group.add_argument(
+        "--wandb_sweep", action="store_true", help="Enable wandb sweep"
+    )
     # rest
     parser.add_argument(
         "--base_model",
@@ -34,42 +38,130 @@ def parse_args():
     return parser.parse_args()
 
-def main():
-    args = parse_args()
-
+def train_single_run(config: TrainingConfig, dataset_path: str):
+    """Run a single training session with given config"""
     try:
-        wandb_config = WandBConfig(
-            enabled=args.wandb_project is not None,
-            project=args.wandb_project,
-            name=args.wandb_name,
-            entity=args.wandb_entity,
-            tags=args.wandb_tags,
-            notes=args.wandb_notes,
-        )
-
-        config = TrainingConfig(
-            base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
-        )
-
         model_handler = ModelHandler(config)
         model, tokenizer = model_handler.setup_model()
 
-        data_loader = DataLoader(tokenizer, config.data.template)
-        dataset = data_loader.load_dataset(args.dataset)
+        data_loader = DataLoader(tokenizer, config.data)
+        dataset = data_loader.load_dataset(dataset_path)
 
         trainer_handler = CustomTrainer(config)
         trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
         trainer_stats = trainer_handler.train_and_log(trainer)
 
-        model.save_pretrained(args.output_dir)
-        tokenizer.save_pretrained(args.output_dir)
+        model.save_pretrained(config.output_dir)
+        tokenizer.save_pretrained(config.output_dir)
 
         print("Training completed successfully!")
+        return trainer_stats
     except Exception as e:
         print(f"Error during training: {e}")
         raise
 
+def run_sweep(base_config: TrainingConfig, dataset_path: str):
+    import wandb
+
+    sweep_config = {
+        "method": "bayes",
+        # HF's WandbCallback reports the evaluation loss as "eval/loss"
+        "metric": {"name": "eval/loss", "goal": "minimize"},
+        "parameters": {
+            "base_learning_rate": {
+                "distribution": "log_uniform_values",
+                "min": 1e-5,
+                "max": 1e-3,
+            },
+            "lora_r": {"values": [8, 16, 32, 64]},
+            "lora_alpha": {"values": [16, 32, 64, 128]},
+            "per_device_train_batch_size": {"values": [16]},
+            "gradient_accumulation_steps": {"values": [2, 4, 8]},
+            "warmup_ratio": {
+                "distribution": "uniform",
+                "min": 0.006,
+                "max": 0.015,
+            },
+            "num_train_epochs": {"values": [1]},
+        },
+        "early_terminate": {"type": "hyperband", "min_iter": 100},
+    }
+
+    sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)
+
+    def compute_scaled_config(config_dict: dict) -> dict:
+        base_batch_size = 16  # ref point
+        effective_batch_size = (
+            config_dict["per_device_train_batch_size"]
+            * config_dict["gradient_accumulation_steps"]
+        )
+
+        # apply square root scaling for Adam
+        lr_scale_factor = (effective_batch_size / base_batch_size) ** 0.5
+
+        # scale the learning rate
+        config_dict["learning_rate"] = (
+            config_dict["base_learning_rate"] * lr_scale_factor
+        )
+
+        config_dict["effective_batch_size"] = effective_batch_size
+        config_dict["lr_scale_factor"] = lr_scale_factor
+
+        del config_dict["base_learning_rate"]
+
+        return config_dict
+
+    def sweep_train():
+        with wandb.init() as run:
+            config_dict = asdict(base_config)
+            # asdict() also turns the nested dataclasses into plain dicts, so
+            # restore the original DataConfig/WandBConfig objects before rebuilding
+            config_dict["data"] = base_config.data
+            config_dict["wandb"] = base_config.wandb
+
+            config_dict.update(wandb.config)
+            config_dict = compute_scaled_config(config_dict)
+
+            wandb.log(
+                {
+                    "effective_batch_size": config_dict["effective_batch_size"],
+                    "lr_scale_factor": config_dict["lr_scale_factor"],
+                    "scaled_learning_rate": config_dict["learning_rate"],
+                }
+            )
+
+            config_dict.pop("effective_batch_size")
+            config_dict.pop("lr_scale_factor")
+
+            run_config = TrainingConfig(**config_dict)
+
+            train_single_run(run_config, dataset_path)
+
+    wandb.agent(sweep_id, function=sweep_train)
+
+def main():
+    args = parse_args()
+
+    wandb_config = WandBConfig(
+        enabled=bool(args.wandb_project),
+        project=args.wandb_project,
+        name=args.wandb_name,
+        entity=args.wandb_entity,
+        tags=args.wandb_tags,
+        notes=args.wandb_notes,
+    )
+
+    base_config = TrainingConfig(
+        base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
+    )
+
+    if args.wandb_sweep:
+        run_sweep(base_config, args.dataset)
+    else:
+        train_single_run(base_config, args.dataset)
+
 
 if __name__ == "__main__":
     main()
diff --git a/trainer.py b/trainer.py
index 414cda1..7220376 100644
--- a/trainer.py
+++ b/trainer.py
@@ -8,6 +8,18 @@ class CustomTrainer:
     def __init__(self, config: TrainingConfig):
         self.config = config
         self._setup_gpu_tracking()
+        self.wandb = None
+        if self.config.wandb.enabled:
+            try:
+                import wandb
+
+                self.wandb = wandb
+            except ImportError:
+                print(
+                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
+                )
+            except Exception as e:
+                print(f"Warning: Failed to initialize wandb: {e}")
 
     def _setup_gpu_tracking(self):
         self.gpu_stats = torch.cuda.get_device_properties(0)
@@ -17,35 +29,24 @@ class CustomTrainer:
         self.max_memory = round(self.gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
 
     def _setup_wandb(self):
-        if self.config.wandb.enabled:
-            try:
-                import wandb
-
-                # Initialize wandb
-                wandb.init(
-                    project=self.config.wandb.project,
-                    name=self.config.wandb.name,
-                    entity=self.config.wandb.entity,
-                    tags=self.config.wandb.tags,
-                    notes=self.config.wandb.notes,
-                    config={
-                        "model": self.config.base_model,
-                        "lora_r": self.config.lora_r,
-                        "lora_alpha": self.config.lora_alpha,
-                        "learning_rate": self.config.learning_rate,
-                        "batch_size": self.config.per_device_train_batch_size,
-                        "epochs": self.config.num_train_epochs,
-                    },
-                )
-                return ["wandb"]
-            except ImportError:
-                print(
-                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
-                )
-                return None
-            except Exception as e:
-                print(f"Warning: Failed to initialize wandb: {e}")
-                return None
+        if self.config.wandb.enabled and self.wandb and self.wandb.run is None:
+            # Initialize wandb
+            self.wandb.init(
+                project=self.config.wandb.project,
+                name=self.config.wandb.name,
+                entity=self.config.wandb.entity,
+                tags=self.config.wandb.tags,
+                notes=self.config.wandb.notes,
+                config={
+                    "model": self.config.base_model,
+                    "lora_r": self.config.lora_r,
+                    "lora_alpha": self.config.lora_alpha,
+                    "learning_rate": self.config.learning_rate,
+                    "batch_size": self.config.per_device_train_batch_size,
+                    "epochs": self.config.num_train_epochs,
+                },
+            )
+            return ["wandb"]
         return None
 
     def create_trainer(self, model, tokenizer, dataset) -> SFTTrainer:
@@ -65,15 +66,21 @@
             bf16=torch.cuda.is_bf16_supported(),
             optim="adamw_8bit",
             report_to=report_to,
-            save_strategy="steps",
-            save_steps=50,
-            save_total_limit=3,
+            save_strategy=self.config.save_strategy,
+            save_steps=self.config.save_steps,
+            save_total_limit=self.config.save_total_limit,
+            fp16_full_eval=self.config.fp16_full_eval,
+            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
+            eval_accumulation_steps=self.config.eval_accumulation_steps,
+            eval_strategy=self.config.eval_strategy,
+            eval_steps=self.config.eval_steps,
         )
 
         return SFTTrainer(
             model=model,
             tokenizer=tokenizer,
-            train_dataset=dataset,
+            train_dataset=dataset["train"],
+            eval_dataset=dataset["test"],
             dataset_text_field="text",
             max_seq_length=self.config.max_seq_length,
             dataset_num_proc=self.config.dataset_num_proc,
@@ -90,11 +97,6 @@
         return trainer_stats
 
     def _log_training_stats(self, trainer_stats):
-        try:
-            import wandb
-        except ImportError:
-            wandb = None
-
         used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
         used_memory_for_lora = round(used_memory - self.start_gpu_memory, 3)
         used_percentage = round(used_memory / self.max_memory * 100, 3)
@@ -111,8 +113,8 @@
             f"Peak reserved memory for training % of max memory = {lora_percentage} %."
         )
 
-        if wandb and self.config.wandb.enabled:
-            wandb.log(
+        if self.wandb and self.config.wandb.enabled:
+            self.wandb.log(
                 {
                     "training_time_seconds": trainer_stats.metrics["train_runtime"],
                     "training_time_minutes": round(