# unsloth-train-scripts/config.py
from dataclasses import dataclass, field


@dataclass
class DataConfig:
    # Prompt template applied to each example; {input} and {output} are
    # filled from the dataset's source/target columns.
    template: str = """Translate this Chinese text to English:
{input}
===
Translation:
{output}"""


@dataclass
class WandBConfig:
    # Weights & Biases run metadata; everything but `enabled` is optional.
    enabled: bool = True
    project: str | None = None
    name: str | None = None
    entity: str | None = None
    tags: list[str] = field(default_factory=list)
    notes: str | None = None
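
# Sketch of how WandBConfig is typically consumed (not part of this file's
# original logic; the `wandb.init` kwargs shown are the library's real ones):
#
#   import wandb
#   cfg = WandBConfig()
#   if cfg.enabled:
#       wandb.init(project=cfg.project, name=cfg.name, entity=cfg.entity,
#                  tags=cfg.tags, notes=cfg.notes)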


@dataclass
class TrainingConfig:
    wandb: WandBConfig = field(default_factory=WandBConfig)
    data: DataConfig = field(default_factory=DataConfig)
    # model
    base_model: str = "unsloth/Qwen2.5-7B"
    max_seq_length: int = 6144
    dtype: str | None = None  # None lets Unsloth auto-detect (fp16/bf16)
    load_in_4bit: bool = False
    # LoRA
    lora_r: int = 256
    lora_alpha: int = 512  # alpha = 2 * r, a common scaling choice
    lora_dropout: float = 0.0
    target_modules: list[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )
    use_gradient_checkpointing: str = "unsloth"
    random_state: int = 3407
    use_rslora: bool = False
    loftq_config: dict | None = None
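
    # Sketch (an assumption, not in the original file) of how the model/LoRA
    # fields above feed Unsloth; kwarg names follow the standard Unsloth
    # examples, with `cfg` being a TrainingConfig instance:
    #
    #   from unsloth import FastLanguageModel
    #   model, tokenizer = FastLanguageModel.from_pretrained(
    #       model_name=cfg.base_model,
    #       max_seq_length=cfg.max_seq_length,
    #       dtype=cfg.dtype,
    #       load_in_4bit=cfg.load_in_4bit,
    #   )
    #   model = FastLanguageModel.get_peft_model(
    #       model,
    #       r=cfg.lora_r,
    #       target_modules=cfg.target_modules,
    #       lora_alpha=cfg.lora_alpha,
    #       lora_dropout=cfg.lora_dropout,
    #       use_gradient_checkpointing=cfg.use_gradient_checkpointing,
    #       random_state=cfg.random_state,
    #       use_rslora=cfg.use_rslora,
    #       loftq_config=cfg.loftq_config,
    #   )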
    # training args
    per_device_train_batch_size: int = 16
    gradient_accumulation_steps: int = 2
    warmup_ratio: float = 0.03
    max_grad_norm: float = 1.0
    num_train_epochs: float = 1.0
    learning_rate: float = 5e-4
    weight_decay: float = 0.0
    lr_scheduler_type: str = "cosine"
    logging_steps: int = 1
    # dataset
    dataset_num_proc: int = 8
    packing: bool = True
    # output
    output_dir: str = "/workspace/output/"
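
    # Sketch (an assumption, not from this file) of how the training fields
    # map onto a TRL SFTTrainer; `SFTConfig` kwarg names mirror
    # `transformers.TrainingArguments`:
    #
    #   from trl import SFTConfig, SFTTrainer
    #   args = SFTConfig(
    #       per_device_train_batch_size=cfg.per_device_train_batch_size,
    #       gradient_accumulation_steps=cfg.gradient_accumulation_steps,
    #       warmup_ratio=cfg.warmup_ratio,
    #       max_grad_norm=cfg.max_grad_norm,
    #       num_train_epochs=cfg.num_train_epochs,
    #       learning_rate=cfg.learning_rate,
    #       weight_decay=cfg.weight_decay,
    #       lr_scheduler_type=cfg.lr_scheduler_type,
    #       logging_steps=cfg.logging_steps,
    #       dataset_num_proc=cfg.dataset_num_proc,
    #       packing=cfg.packing,
    #       output_dir=cfg.output_dir,
    #   )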

    def __post_init__(self):
        # Fallback: if an empty target_modules list is passed in explicitly,
        # restore the default projection set used above.
        if not self.target_modules:
            self.target_modules = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ]
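

if __name__ == "__main__":
    # Quick smoke test (illustrative, not part of the original script):
    # instantiate the config and report the effective batch size, i.e.
    # per-device batch * gradient accumulation steps = 16 * 2 = 32.
    cfg = TrainingConfig()
    effective_batch = (
        cfg.per_device_train_batch_size * cfg.gradient_accumulation_steps
    )
    print(f"base model:      {cfg.base_model}")
    print(f"LoRA r/alpha:    {cfg.lora_r}/{cfg.lora_alpha}")
    print(f"effective batch: {effective_batch} sequences per optimizer step")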