from dataclasses import dataclass, field


@dataclass
class DataConfig:
    # Default prompt template: Chinese source text, then the English translation.
    template: str = """Translate this Chinese text to English: {}
===
Translation: {}"""


@dataclass
class WandBConfig:
    enabled: bool = True
    project: str = "lora-training"
    name: str | None = None
    entity: str | None = None
    # Mutable defaults must go through default_factory;
    # a bare `[]` default makes dataclasses raise ValueError.
    tags: list[str] = field(default_factory=list)
    notes: str | None = None


@dataclass
class TrainingConfig:
    # Nested configs are also mutable, so they need default_factory too.
    wandb: WandBConfig = field(default_factory=WandBConfig)
    data: DataConfig = field(default_factory=DataConfig)

    # model
    base_model: str = "unsloth/Qwen2.5-7B"
    max_seq_length: int = 6144
    dtype: str | None = None
    load_in_4bit: bool = True

    # LoRA
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    target_modules: list[str] = field(default_factory=list)
    use_gradient_checkpointing: str = "unsloth"
    random_state: int = 3407
    use_rslora: bool = False
    loftq_config: dict | None = None

    # training args
    per_device_train_batch_size: int = 32
    gradient_accumulation_steps: int = 1
    warmup_ratio: float = 0.05
    max_grad_norm: float = 1.0
    num_train_epochs: float = 0.5
    learning_rate: float = 3e-5
    weight_decay: float = 0.05
    lr_scheduler_type: str = "linear"
    logging_steps: int = 5

    # dataset
    dataset_num_proc: int = 2
    packing: bool = False

    # output
    output_dir: str = "/output/"

    def __post_init__(self):
        # Fall back to the standard set of LoRA target modules
        # (attention and MLP projections) when none are given.
        if not self.target_modules:
            self.target_modules = [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
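

# Minimal usage sketch. The field names come from the dataclasses above, but
# the override values (run name, learning rate, sample sentence) are
# illustrative assumptions, not part of the original configuration.
if __name__ == "__main__":
    cfg = TrainingConfig()
    # __post_init__ fills target_modules when the list is left empty.
    print(cfg.target_modules)  # ['q_proj', 'k_proj', ..., 'down_proj']

    # Nested configs can be overridden per run (hypothetical run name):
    cfg = TrainingConfig(
        wandb=WandBConfig(project="lora-training", name="qwen2.5-7b-zh-en"),
        learning_rate=1e-4,
    )

    # The data template takes two slots: source text, then target text
    # (left empty at inference time).
    print(cfg.data.template.format("你好，世界。", ""))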