diff --git a/config.py b/config.py
index af0e167..043e7a3 100644
--- a/config.py
+++ b/config.py
@@ -1,64 +1,73 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
 @dataclass
 class DataConfig:
-    # Default configuration
     template: str = """Translate this Chinese text to English:
-{}
+{input}
 ===
 Translation:
-{}"""
+{output}"""
 
 
 @dataclass
 class WandBConfig:
     enabled: bool = True
-    project: str = "lora-training"
+    project: str | None = None
     name: str | None = None
     entity: str | None = None
-    tags: list[str] = []
+    tags: list[str] = field(default_factory=list)
     notes: str | None = None
 
 
 @dataclass
 class TrainingConfig:
-    wandb: WandBConfig = WandBConfig()
-    data: DataConfig = DataConfig()
+    wandb: WandBConfig = field(default_factory=WandBConfig)
+    data: DataConfig = field(default_factory=DataConfig)
 
     # model
     base_model: str = "unsloth/Qwen2.5-7B"
     max_seq_length: int = 6144
     dtype: str | None = None
-    load_in_4bit: bool = True
+    load_in_4bit: bool = False
 
     # LoRA
-    lora_r: int = 16
-    lora_alpha: int = 16
+    lora_r: int = 256
+    lora_alpha: int = 512
     lora_dropout: float = 0
-    target_modules: list[str] = []
+    target_modules: list[str] = field(
+        default_factory=lambda: [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ]
+    )
     use_gradient_checkpointing: str = "unsloth"
     random_state: int = 3407
     use_rslora: bool = False
     loftq_config: dict | None = None
 
     # training args
-    per_device_train_batch_size: int = 32
-    gradient_accumulation_steps: int = 1
-    warmup_ratio: float = 0.05
+    per_device_train_batch_size: int = 16
+    gradient_accumulation_steps: int = 2
+    warmup_ratio: float = 0.03
     max_grad_norm: float = 1.0
-    num_train_epochs: float = 0.5
-    learning_rate: float = 3e-5
-    weight_decay: float = 0.05
-    lr_scheduler_type: str = "linear"
-    logging_steps: int = 5
+    num_train_epochs: float = 1
+    learning_rate: float = 5e-4
+    weight_decay: float = 0
+    lr_scheduler_type: str = "cosine"
+    logging_steps: int = 1
 
     # dataset
-    dataset_num_proc: int = 2
-    packing: bool = False
+    dataset_num_proc: int = 8
+    packing: bool = True
 
     # output
-    output_dir: str = "/output/"
+    output_dir: str = "/workspace/output/"
 
     def __post_init__(self):
        if not self.target_modules:
diff --git a/main.py b/main.py
index 96697ab..555461f 100644
--- a/main.py
+++ b/main.py
@@ -10,9 +10,7 @@ def parse_args():
 
     # wandb args
     wandb_group = parser.add_argument_group("Weights & Biases")
-    wandb_group.add_argument(
-        "--wandb_project", type=str, default="lora-training", help="WandB project name"
-    )
+    wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
     wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
     wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
     wandb_group.add_argument(
@@ -42,7 +40,7 @@ def main():
     try:
         wandb_config = WandBConfig(
             enabled=args.wandb_project is not None,
-            project=args.wandb_project or "lora-training",
+            project=args.wandb_project,
             name=args.wandb_name,
             entity=args.wandb_entity,
             tags=args.wandb_tags,
diff --git a/trainer.py b/trainer.py
index 67d64d1..414cda1 100644
--- a/trainer.py
+++ b/trainer.py
@@ -65,6 +65,9 @@ class CustomTrainer:
             bf16=torch.cuda.is_bf16_supported(),
             optim="adamw_8bit",
             report_to=report_to,
+            save_strategy="steps",
+            save_steps=50,
+            save_total_limit=3,
         )
 
         return SFTTrainer(
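
Note on the field(default_factory=...) changes in config.py: the dataclasses module rejects mutable defaults (a bare [] and, on newer Python versions, unhashable objects such as non-frozen dataclass instances), because a plain default is evaluated once and shared by every instance. Below is a minimal standalone sketch of that behavior; it is not taken from the patched files, and the Broken/Fixed names are illustrative only.

from dataclasses import dataclass, field

# A bare mutable default is rejected when the class object is created.
try:
    @dataclass
    class Broken:
        tags: list[str] = []  # ValueError: mutable default ... use default_factory
except ValueError as exc:
    print("rejected:", exc)


@dataclass
class Fixed:
    tags: list[str] = field(default_factory=list)  # fresh list per instance


a, b = Fixed(), Fixed()
a.tags.append("lora")
print(a.tags, b.tags)  # ['lora'] [] -- no state shared between instances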