chore: update training defaults, wandb handling, and checkpoint saving

--- a/config.py
+++ b/config.py
@@ -1,64 +1,73 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
 @dataclass
 class DataConfig:
-    # Default configuration
     template: str = """Translate this Chinese text to English:
-{}
+{input}
 ===
 Translation:
-{}"""
+{output}"""
 
 
 @dataclass
 class WandBConfig:
     enabled: bool = True
-    project: str = "lora-training"
+    project: str | None = None
     name: str | None = None
     entity: str | None = None
-    tags: list[str] = []
+    tags: list[str] = field(default_factory=list)
     notes: str | None = None
 
 
 @dataclass
 class TrainingConfig:
-    wandb: WandBConfig = WandBConfig()
-    data: DataConfig = DataConfig()
+    wandb: WandBConfig = field(default_factory=WandBConfig)
+    data: DataConfig = field(default_factory=DataConfig)
 
     # model
     base_model: str = "unsloth/Qwen2.5-7B"
     max_seq_length: int = 6144
     dtype: str | None = None
-    load_in_4bit: bool = True
+    load_in_4bit: bool = False
 
     # LoRA
-    lora_r: int = 16
-    lora_alpha: int = 16
+    lora_r: int = 256
+    lora_alpha: int = 512
     lora_dropout: float = 0
-    target_modules: list[str] = []
+    target_modules: list[str] = field(
+        default_factory=lambda: [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ]
+    )
     use_gradient_checkpointing: str = "unsloth"
     random_state: int = 3407
     use_rslora: bool = False
     loftq_config: dict | None = None
 
     # training args
-    per_device_train_batch_size: int = 32
-    gradient_accumulation_steps: int = 1
-    warmup_ratio: float = 0.05
+    per_device_train_batch_size: int = 16
+    gradient_accumulation_steps: int = 2
+    warmup_ratio: float = 0.03
     max_grad_norm: float = 1.0
-    num_train_epochs: float = 0.5
-    learning_rate: float = 3e-5
-    weight_decay: float = 0.05
-    lr_scheduler_type: str = "linear"
-    logging_steps: int = 5
+    num_train_epochs: float = 1
+    learning_rate: float = 5e-4
+    weight_decay: float = 0
+    lr_scheduler_type: str = "cosine"
+    logging_steps: int = 1
 
     # dataset
-    dataset_num_proc: int = 2
-    packing: bool = False
+    dataset_num_proc: int = 8
+    packing: bool = True
 
     # output
-    output_dir: str = "/output/"
+    output_dir: str = "/workspace/output/"
 
     def __post_init__(self):
         if not self.target_modules:
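
Note on the config.py hunk above: it swaps plain mutable defaults ([], WandBConfig(), DataConfig()) for field(default_factory=...). A minimal, self-contained sketch (not part of the commit) of the failure mode this avoids:

# Sketch only: illustrates the default_factory pattern adopted in config.py above.
from dataclasses import dataclass, field


@dataclass
class Fixed:
    # A bare `tags: list[str] = []` default is rejected by dataclasses at class
    # definition time ("mutable default ... is not allowed: use default_factory").
    tags: list[str] = field(default_factory=list)


@dataclass
class Outer:
    # Nested configs get a fresh instance per object instead of one shared default.
    inner: Fixed = field(default_factory=Fixed)


a, b = Outer(), Outer()
a.inner.tags.append("run-1")
assert b.inner.tags == []  # no state leaks between instances

The template change from positional {} placeholders to named {input}/{output} implies the call site now formats with keyword arguments, e.g. DataConfig().template.format(input=src, output=tgt) (assumed; the formatting call is not shown in this diff).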

--- a/main.py
+++ b/main.py
@@ -10,9 +10,7 @@ def parse_args():
 
     # wandb args
     wandb_group = parser.add_argument_group("Weights & Biases")
-    wandb_group.add_argument(
-        "--wandb_project", type=str, default="lora-training", help="WandB project name"
-    )
+    wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
     wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
     wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")
     wandb_group.add_argument(
@@ -42,7 +40,7 @@ def main():
     try:
         wandb_config = WandBConfig(
             enabled=args.wandb_project is not None,
-            project=args.wandb_project or "lora-training",
+            project=args.wandb_project,
             name=args.wandb_name,
             entity=args.wandb_entity,
             tags=args.wandb_tags,
@@ -65,6 +65,9 @@ class CustomTrainer:
             bf16=torch.cuda.is_bf16_supported(),
             optim="adamw_8bit",
             report_to=report_to,
+            save_strategy="steps",
+            save_steps=50,
+            save_total_limit=3,
         )
 
         return SFTTrainer(
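
Note on the main.py hunks: dropping the implicit "lora-training" project default makes Weights & Biases logging opt-in via --wandb_project, and the new save_* kwargs add periodic checkpointing. A self-contained sketch (assumptions: the argument wiring mirrors the diff; the truncated --wandb_tags argument is omitted) of the behaviour when the flag is not passed:

# Sketch only: reconstructs the argparse -> WandBConfig wiring shown in the diff.
import argparse
from dataclasses import dataclass, field


@dataclass
class WandBConfig:
    enabled: bool = True
    project: str | None = None
    name: str | None = None
    entity: str | None = None
    tags: list[str] = field(default_factory=list)
    notes: str | None = None


parser = argparse.ArgumentParser()
wandb_group = parser.add_argument_group("Weights & Biases")
wandb_group.add_argument("--wandb_project", type=str, help="WandB project name")
wandb_group.add_argument("--wandb_name", type=str, help="WandB run name")
wandb_group.add_argument("--wandb_entity", type=str, help="WandB entity/username")

args = parser.parse_args([])  # simulate running without --wandb_project
cfg = WandBConfig(
    enabled=args.wandb_project is not None,  # logging stays off unless the flag is set
    project=args.wandb_project,              # None now, instead of the old "lora-training" fallback
    name=args.wandb_name,
    entity=args.wandb_entity,
)
assert cfg.enabled is False and cfg.project is None

The added save_strategy="steps", save_steps=50, save_total_limit=3 kwargs appear to be passed to a Hugging Face TrainingArguments-style object (alongside bf16, optim, report_to), so the run checkpoints every 50 steps and keeps only the three most recent checkpoints.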