chore: trainer

2025-02-13 21:42:03 +06:00
parent df7883dde3
commit f013e8efe6
7 changed files with 431 additions and 0 deletions

config.py Normal file

@@ -0,0 +1,73 @@
from dataclasses import dataclass, field


@dataclass
class DataConfig:
    # Prompt template: first slot is the Chinese source text, second the English translation.
    template: str = """Translate this Chinese text to English:
{}
===
Translation:
{}"""


@dataclass
class WandBConfig:
    enabled: bool = True
    project: str = "lora-training"
    name: str | None = None
    entity: str | None = None
    # Mutable defaults must go through default_factory in dataclasses.
    tags: list[str] = field(default_factory=list)
    notes: str | None = None


@dataclass
class TrainingConfig:
    wandb: WandBConfig = field(default_factory=WandBConfig)
    data: DataConfig = field(default_factory=DataConfig)

    # model
    base_model: str = "unsloth/Qwen2.5-7B"
    max_seq_length: int = 6144
    dtype: str | None = None
    load_in_4bit: bool = True

    # LoRA
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0
    target_modules: list[str] = field(default_factory=list)
    use_gradient_checkpointing: str = "unsloth"
    random_state: int = 3407
    use_rslora: bool = False
    loftq_config: dict | None = None

    # training args
    per_device_train_batch_size: int = 32
    gradient_accumulation_steps: int = 1
    warmup_ratio: float = 0.05
    max_grad_norm: float = 1.0
    num_train_epochs: float = 0.5
    learning_rate: float = 3e-5
    weight_decay: float = 0.05
    lr_scheduler_type: str = "linear"
    logging_steps: int = 5

    # dataset
    dataset_num_proc: int = 2
    packing: bool = False

    # output
    output_dir: str = "/output/"

    def __post_init__(self):
        # Default to the standard attention and MLP projection layers if none are given.
        if not self.target_modules:
            self.target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]
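
For reference, a minimal usage sketch of the config above; it is not part of this commit, and the field overrides shown are only illustrative. It constructs a TrainingConfig, lets __post_init__ fill in target_modules, and renders the translation prompt template.

# Hypothetical usage sketch (assumes config.py is importable from the working directory).
from config import TrainingConfig

cfg = TrainingConfig(learning_rate=1e-4, num_train_epochs=1.0)
print(cfg.target_modules)  # populated by __post_init__ when left empty
prompt = cfg.data.template.format("你好，世界", "Hello, world")
print(prompt)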