chore: up

This commit is contained in:
2025-02-14 02:13:40 +06:00
parent f013e8efe6
commit 45f45f4bdb
3 changed files with 37 additions and 27 deletions

View File

@@ -1,64 +1,73 @@
from dataclasses import dataclass from dataclasses import dataclass, field
@dataclass @dataclass
class DataConfig: class DataConfig:
# Default configuration
template: str = """Translate this Chinese text to English: template: str = """Translate this Chinese text to English:
{} {input}
=== ===
Translation: Translation:
{}""" {output}"""
@dataclass @dataclass
class WandBConfig: class WandBConfig:
enabled: bool = True enabled: bool = True
project: str = "lora-training" project: str | None = None
name: str | None = None name: str | None = None
entity: str | None = None entity: str | None = None
tags: list[str] = [] tags: list[str] = field(default_factory=list)
notes: str | None = None notes: str | None = None
@dataclass @dataclass
class TrainingConfig: class TrainingConfig:
wandb: WandBConfig = WandBConfig() wandb: WandBConfig = field(default_factory=WandBConfig)
data: DataConfig = DataConfig() data: DataConfig = field(default_factory=DataConfig)
# model # model
base_model: str = "unsloth/Qwen2.5-7B" base_model: str = "unsloth/Qwen2.5-7B"
max_seq_length: int = 6144 max_seq_length: int = 6144
dtype: str | None = None dtype: str | None = None
load_in_4bit: bool = True load_in_4bit: bool = False
# LoRA # LoRA
lora_r: int = 16 lora_r: int = 256
lora_alpha: int = 16 lora_alpha: int = 512
lora_dropout: float = 0 lora_dropout: float = 0
target_modules: list[str] = [] target_modules: list[str] = field(
default_factory=lambda: [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
)
use_gradient_checkpointing: str = "unsloth" use_gradient_checkpointing: str = "unsloth"
random_state: int = 3407 random_state: int = 3407
use_rslora: bool = False use_rslora: bool = False
loftq_config: dict | None = None loftq_config: dict | None = None
# training args # training args
per_device_train_batch_size: int = 32 per_device_train_batch_size: int = 16
gradient_accumulation_steps: int = 1 gradient_accumulation_steps: int = 2
warmup_ratio: float = 0.05 warmup_ratio: float = 0.03
max_grad_norm: float = 1.0 max_grad_norm: float = 1.0
num_train_epochs: float = 0.5 num_train_epochs: float = 1
learning_rate: float = 3e-5 learning_rate: float = 5e-4
weight_decay: float = 0.05 weight_decay: float = 0
lr_scheduler_type: str = "linear" lr_scheduler_type: str = "cosine"
logging_steps: int = 5 logging_steps: int = 1
# dataset # dataset
dataset_num_proc: int = 2 dataset_num_proc: int = 8
packing: bool = False packing: bool = True
# output # output
output_dir: str = "/output/" output_dir: str = "/workspace/output/"
def __post_init__(self): def __post_init__(self):
if not self.target_modules: if not self.target_modules:

View File

@@ -10,9 +10,7 @@ def parse_args():
# wandb args # wandb args
wandb_group = parser.add_argument_group("Weights & Biases") wandb_group = parser.add_argument_group("Weights & Biases")
wandb_group.add_argument( wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
"--wandb_project", type=str, default="lora-training", help="WandB project name"
)
wandb_group.add_argument("--wandb_name", type=str, help="WandB run name") wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username") wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
wandb_group.add_argument( wandb_group.add_argument(
@@ -42,7 +40,7 @@ def main():
try: try:
wandb_config = WandBConfig( wandb_config = WandBConfig(
enabled=args.wandb_project is not None, enabled=args.wandb_project is not None,
project=args.wandb_project or "lora-training", project=args.wandb_project,
name=args.wandb_name, name=args.wandb_name,
entity=args.wandb_entity, entity=args.wandb_entity,
tags=args.wandb_tags, tags=args.wandb_tags,

View File

@@ -65,6 +65,9 @@ class CustomTrainer:
bf16=torch.cuda.is_bf16_supported(), bf16=torch.cuda.is_bf16_supported(),
optim="adamw_8bit", optim="adamw_8bit",
report_to=report_to, report_to=report_to,
save_strategy="steps",
save_steps=50,
save_total_limit=3,
) )
return SFTTrainer( return SFTTrainer(