from dataclasses import dataclass, field


@dataclass
class DataConfig:
    template: str = """Translate this Chinese text to English:
{}
---
Translation:
{}"""
    train_split: float = 0.9
    max_samples: int | None = 3000


@dataclass
class WandBConfig:
    enabled: bool = False
    project: str | None = None
    name: str | None = None
    entity: str | None = None
    tags: list[str] = field(default_factory=list)
    notes: str | None = None


@dataclass
class TrainingConfig:
    wandb: WandBConfig = field(default_factory=WandBConfig)
    data: DataConfig = field(default_factory=DataConfig)

    # model
    base_model: str = "unsloth/Qwen2.5-7B"
    max_seq_length: int = 6144
    dtype: str | None = None
    load_in_4bit: bool = False

    # LoRA
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0
    target_modules: list[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )
    use_gradient_checkpointing: str = "unsloth"
    random_state: int = 3407
    use_rslora: bool = False
    loftq_config: dict | None = None

    # training args
    per_device_train_batch_size: int = 16
    gradient_accumulation_steps: int = 2
    warmup_ratio: float = 0.05
    max_grad_norm: float = 1.0
    num_train_epochs: float = 1
    learning_rate: float = 5e-4
    weight_decay: float = 0
    lr_scheduler_type: str = "cosine"
    logging_steps: int = 1

    # save
    save_strategy: str = "steps"
    save_steps: float = 100
    save_total_limit: int | None = 3

    # dataset
    dataset_num_proc: int = 8
    packing: bool = True

    # eval
    fp16_full_eval: bool = True
    per_device_eval_batch_size: int = 16
    eval_accumulation_steps: int = 4
    eval_strategy: str = "steps"
    eval_steps: int = 5

    # output
    output_dir: str = "/workspace/output/"

    def __post_init__(self):
        # Guard against an explicitly passed empty list (e.g. from a loaded
        # config file): fall back to targeting all attention and MLP
        # projections, mirroring the field's default_factory above.
        if not self.target_modules:
            self.target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]
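

# Usage sketch (not part of the original module): shows how the nested
# dataclasses compose and how DataConfig.template is filled. The project
# name and sample text below are hypothetical placeholders.
if __name__ == "__main__":
    config = TrainingConfig(
        wandb=WandBConfig(enabled=True, project="zh-en-translation"),
        data=DataConfig(max_samples=500),
    )
    # Effective per-device train batch size:
    # per_device_train_batch_size * gradient_accumulation_steps = 16 * 2 = 32.
    print(config.per_device_train_batch_size * config.gradient_accumulation_steps)
    # The template has two slots: source text first, then the completion,
    # left empty here as it would be at inference time.
    print(config.data.template.format("你好，世界。", ""))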