chore: add checkpoint/eval settings, dataset train/test split, and wandb sweep support

config.py (16 lines changed)

@@ -5,9 +5,11 @@ from dataclasses import dataclass, field
 class DataConfig:
     template: str = """Translate this Chinese text to English:
 {input}
-===
+---
 Translation:
 {output}"""
+    train_split: float = 0.9
+    max_samples: int | None = None


 @dataclass
@@ -62,10 +64,22 @@ class TrainingConfig:
     lr_scheduler_type: str = "cosine"
     logging_steps: int = 1

+    # save
+    save_strategy: str = "steps"
+    save_steps: float = 100
+    save_total_limit: int | None = 3
+
     # dataset
     dataset_num_proc: int = 8
     packing: bool = True

+    # eval
+    fp16_full_eval: bool = True
+    per_device_eval_batch_size: int = 2
+    eval_accumulation_steps: int = 4
+    eval_strategy: str = "steps"
+    eval_steps: int = 1
+
     # output
     output_dir: str = "/workspace/output/"

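
Editor's note: the two new DataConfig fields drive the loading changes below. train_split is the fraction kept for training (the rest becomes the eval split) and max_samples optionally caps the dataset before splitting. A minimal standalone sketch, using a trimmed stand-in for the real DataConfig in config.py, only to illustrate how the fields are meant to be consumed:

    from dataclasses import dataclass

    @dataclass
    class DataConfig:
        # trimmed copy of the fields added above, for illustration only
        template: str = "Translate this Chinese text to English:\n{input}\n---\nTranslation:\n{output}"
        train_split: float = 0.9
        max_samples: int | None = None

    cfg = DataConfig(train_split=0.8, max_samples=1000)
    test_size = 1 - cfg.train_split  # fraction held out for evaluation (~0.2 here)
    prompt = cfg.template.format(input="你好", output="Hello")
    print(test_size)
    print(prompt)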

data_loader.py

@@ -1,13 +1,15 @@
 import os
 from typing import Any
+from config import DataConfig
 from datasets import Dataset, load_dataset
 from transformers import PreTrainedTokenizer


 class DataLoader:
-    def __init__(self, tokenizer: PreTrainedTokenizer, template: str):
+    def __init__(self, tokenizer: PreTrainedTokenizer, data_config: DataConfig):
         self.tokenizer = tokenizer
-        self._template = template
+        self.data_config = data_config
+        # self._template = template

     def load_dataset(self, path: str) -> Dataset:
         """Load dataset from local path or Google Drive"""
@@ -28,7 +30,17 @@ class DataLoader:

         try:
             dataset = load_dataset("json", data_files=dataset_path, split="train")
-            return self.process_dataset(dataset)
+
+            if (max_size := self.data_config.max_samples) is not None:
+                dataset = dataset.select(range(min(len(dataset), max_size)))
+
+            processed_dataset = self.process_dataset(dataset)
+            # train/test split
+            split_dataset = processed_dataset.train_test_split(
+                test_size=(1 - self.data_config.train_split), shuffle=False
+            )
+
+            return split_dataset
         except Exception as e:
             raise Exception(f"Error loading dataset: {e}")

@@ -41,7 +53,7 @@ class DataLoader:
         texts: list[str] = []
         for input, output in zip(inputs, outputs):
             text = (
-                self._template.format(input=input, output=output)
+                self.data_config.template.format(input=input, output=output)
                 + self.tokenizer.eos_token
             )
             texts.append(text)
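
Editor's note: with this change load_dataset effectively returns a DatasetDict with "train" and "test" keys (the unchanged -> Dataset annotation no longer matches), which is what lets trainer.py below index dataset["train"] and dataset["test"]. A small standalone sketch of the new behaviour on toy data; names and sizes here are illustrative only:

    from datasets import Dataset

    # toy stand-in for the JSON translation pairs
    raw = Dataset.from_dict({"text": [f"example {i}" for i in range(10)]})

    # optional cap, mirroring the max_samples guard above
    max_samples = 8
    raw = raw.select(range(min(len(raw), max_samples)))

    # train_split=0.9 -> test_size=0.1; shuffle=False keeps document order,
    # so the tail of the file becomes the eval split
    splits = raw.train_test_split(test_size=1 - 0.9, shuffle=False)
    print(splits["train"].num_rows, splits["test"].num_rows)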

main.py (134 lines changed)

@@ -3,6 +3,7 @@ from config import TrainingConfig, WandBConfig
 from data_loader import DataLoader
 from model_handler import ModelHandler
 from trainer import CustomTrainer
+from dataclasses import asdict


 def parse_args():
@@ -17,6 +18,9 @@ def parse_args():
         "--wandb_tags", type=str, nargs="+", help="WandB tags for the run"
     )
     wandb_group.add_argument("--wandb_notes", type=str, help="Notes for the WandB run")
+    wandb_group.add_argument(
+        "--wandb_sweep", action="store_true", help="Enable wandb sweep"
+    )
     # rest
     parser.add_argument(
         "--base_model",
@@ -34,12 +38,114 @@ def parse_args():
     return parser.parse_args()


+def train_single_run(config: TrainingConfig, dataset_path: str):
+    """Run a single training session with given config"""
+    try:
+        model_handler = ModelHandler(config)
+        model, tokenizer = model_handler.setup_model()
+
+        data_loader = DataLoader(tokenizer, config.data)
+        dataset = data_loader.load_dataset(dataset_path)
+
+        trainer_handler = CustomTrainer(config)
+        trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
+        trainer_stats = trainer_handler.train_and_log(trainer)
+
+        model.save_pretrained(config.output_dir)
+        tokenizer.save_pretrained(config.output_dir)
+
+        print("Training completed successfully!")
+        return trainer_stats
+
+    except Exception as e:
+        print(f"Error during training: {e}")
+        raise
+
+
+def run_sweep(base_config: TrainingConfig, dataset_path: str):
+    import wandb
+
+    sweep_config = {
+        "method": "bayes",
+        "metric": {"name": "val_loss", "goal": "minimize"},
+        "parameters": {
+            "base_learning_rate": {
+                "distribution": "log_uniform_values",
+                "min": 1e-5,
+                "max": 1e-3,
+            },
+            "lora_r": {"values": [8, 16, 32, 64]},
+            "lora_alpha": {"values": [16, 32, 64, 128]},
+            "per_device_train_batch_size": {"values": [16]},
+            "gradient_accumulation_steps": {"values": [2, 4, 8]},
+            "warmup_ratio": {
+                "distribution": "uniform",
+                "min": 0.006,
+                "max": 0.015,
+            },
+            "num_train_epochs": {"values": [1]},
+        },
+        "early_terminate": {"type": "hyperband", "min_iter": 100},
+    }
+
+    sweep_id = wandb.sweep(sweep_config, project=base_config.wandb.project)
+
+    def compute_scaled_config(config_dict: dict) -> dict:
+        base_batch_size = 16  # ref point
+        effective_batch_size = (
+            config_dict["per_device_train_batch_size"]
+            * config_dict["gradient_accumulation_steps"]
+        )
+
+        # apply square root scaling for Adam
+        lr_scale_factor = (effective_batch_size / base_batch_size) ** 0.5
+
+        # scale the learning rate
+        config_dict["learning_rate"] = (
+            config_dict["base_learning_rate"] * lr_scale_factor
+        )
+
+        config_dict["effective_batch_size"] = effective_batch_size
+        config_dict["lr_scale_factor"] = lr_scale_factor
+
+        del config_dict["base_learning_rate"]
+
+        return config_dict
+
+    def sweep_train():
+        with wandb.init() as run:
+            config_dict = (
+                asdict(base_config)
+                if hasattr(base_config, "__dataclass_fields__")
+                else base_config.__dict__.copy()
+            )
+
+            config_dict.update(wandb.config)
+            config_dict = compute_scaled_config(config_dict)
+
+            wandb.log(
+                {
+                    "effective_batch_size": config_dict["effective_batch_size"],
+                    "lr_scale_factor": config_dict["lr_scale_factor"],
+                    "scaled_learning_rate": config_dict["learning_rate"],
+                }
+            )
+
+            config_dict.pop("effective_batch_size")
+            config_dict.pop("lr_scale_factor")
+
+            run_config = TrainingConfig(**config_dict)
+
+            train_single_run(run_config, dataset_path)
+
+    wandb.agent(sweep_id, function=sweep_train)
+
+
 def main():
     args = parse_args()

-    try:
-        wandb_config = WandBConfig(
-            enabled=args.wandb_project is not None,
+    wandb_config = WandBConfig(
+        enabled=bool(args.wandb_project),
         project=args.wandb_project,
         name=args.wandb_name,
         entity=args.wandb_entity,
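
Editor's note: a worked sketch of the square-root scaling rule used by compute_scaled_config above. With the reference batch size of 16 and, for example, per-device batch 16 with 4 gradient-accumulation steps, the effective batch is 64, so the sampled base learning rate is multiplied by sqrt(64 / 16) = 2. The numbers below are one possible sweep draw, not fixed values:

    # standalone illustration of the scaling applied in compute_scaled_config
    base_batch_size = 16
    per_device_train_batch_size = 16
    gradient_accumulation_steps = 4
    base_learning_rate = 2e-4

    effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 64
    lr_scale_factor = (effective_batch_size / base_batch_size) ** 0.5                 # 2.0
    learning_rate = base_learning_rate * lr_scale_factor                              # 4e-4
    print(effective_batch_size, lr_scale_factor, learning_rate)
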
@@ -47,28 +153,14 @@ def main():
         notes=args.wandb_notes,
     )

-    config = TrainingConfig(
+    base_config = TrainingConfig(
         base_model=args.base_model, output_dir=args.output_dir, wandb=wandb_config
     )

-    model_handler = ModelHandler(config)
-    model, tokenizer = model_handler.setup_model()
-
-    data_loader = DataLoader(tokenizer, config.data.template)
-    dataset = data_loader.load_dataset(args.dataset)
-
-    trainer_handler = CustomTrainer(config)
-    trainer = trainer_handler.create_trainer(model, tokenizer, dataset)
-    trainer_stats = trainer_handler.train_and_log(trainer)
-
-    model.save_pretrained(args.output_dir)
-    tokenizer.save_pretrained(args.output_dir)
-
-    print("Training completed successfully!")
-
-    except Exception as e:
-        print(f"Error during training: {e}")
-        raise
+    if args.wandb_sweep:
+        run_sweep(base_config, args.dataset)
+    else:
+        train_single_run(base_config, args.dataset)


 if __name__ == "__main__":

trainer.py (50 lines changed)

@@ -8,6 +8,18 @@ class CustomTrainer:
     def __init__(self, config: TrainingConfig):
         self.config = config
         self._setup_gpu_tracking()
+        self.wandb = None
+        if self.config.wandb.enabled:
+            try:
+                import wandb
+
+                self.wandb = wandb
+            except ImportError:
+                print(
+                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
+                )
+            except Exception as e:
+                print(f"Warning: Failed to initialize wandb: {e}")

     def _setup_gpu_tracking(self):
         self.gpu_stats = torch.cuda.get_device_properties(0)
@@ -17,12 +29,9 @@ class CustomTrainer:
         self.max_memory = round(self.gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

     def _setup_wandb(self):
-        if self.config.wandb.enabled:
-            try:
-                import wandb
-
+        if self.config.wandb.enabled and self.wandb and self.wandb.run is None:
             # Initialize wandb
-            wandb.init(
+            self.wandb.init(
                 project=self.config.wandb.project,
                 name=self.config.wandb.name,
                 entity=self.config.wandb.entity,
@@ -38,14 +47,6 @@ class CustomTrainer:
                 },
             )
             return ["wandb"]
-            except ImportError:
-                print(
-                    "Warning: wandb not installed. Run `pip install wandb` to enable logging."
-                )
-                return None
-            except Exception as e:
-                print(f"Warning: Failed to initialize wandb: {e}")
-                return None
         return None

     def create_trainer(self, model, tokenizer, dataset) -> SFTTrainer:
@@ -65,15 +66,21 @@ class CustomTrainer:
             bf16=torch.cuda.is_bf16_supported(),
             optim="adamw_8bit",
             report_to=report_to,
-            save_strategy="steps",
-            save_steps=50,
-            save_total_limit=3,
+            save_strategy=self.config.save_strategy,
+            save_steps=self.config.save_steps,
+            save_total_limit=self.config.save_total_limit,
+            fp16_full_eval=self.config.fp16_full_eval,
+            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
+            eval_accumulation_steps=self.config.eval_accumulation_steps,
+            eval_strategy=self.config.eval_strategy,
+            eval_steps=self.config.eval_steps,
         )

         return SFTTrainer(
             model=model,
             tokenizer=tokenizer,
-            train_dataset=dataset,
+            train_dataset=dataset["train"],
+            eval_dataset=dataset["test"],
             dataset_text_field="text",
             max_seq_length=self.config.max_seq_length,
             dataset_num_proc=self.config.dataset_num_proc,
@@ -90,11 +97,6 @@ class CustomTrainer:
         return trainer_stats

     def _log_training_stats(self, trainer_stats):
-        try:
-            import wandb
-        except ImportError:
-            wandb = None
-
         used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
         used_memory_for_lora = round(used_memory - self.start_gpu_memory, 3)
         used_percentage = round(used_memory / self.max_memory * 100, 3)
@@ -111,8 +113,8 @@ class CustomTrainer:
             f"Peak reserved memory for training % of max memory = {lora_percentage} %."
         )

-        if wandb and self.config.wandb.enabled:
-            wandb.log(
+        if self.wandb and self.config.wandb.enabled:
+            self.wandb.log(
                 {
                     "training_time_seconds": trainer_stats.metrics["train_runtime"],
                     "training_time_minutes": round(
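
Editor's note: the trainer refactor above moves the wandb import into __init__, stores the module on self, and has every later call site check self.wandb instead of re-importing. A minimal sketch of that guarded-import pattern with simplified class and attribute names (illustrative only, not the project's API):

    class WandbBackedLogger:
        """Mirrors the guarded wandb handling added to CustomTrainer."""

        def __init__(self, enabled: bool):
            self.wandb = None
            if enabled:
                try:
                    import wandb  # optional dependency; absence downgrades to print-only
                    self.wandb = wandb
                except ImportError:
                    print("Warning: wandb not installed. Run `pip install wandb` to enable logging.")

        def log(self, metrics: dict):
            if self.wandb and self.wandb.run is not None:
                self.wandb.log(metrics)
            else:
                print(metrics)


    logger = WandbBackedLogger(enabled=False)
    logger.log({"train_runtime": 123.4})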