chore: ss
config.py | 24 changed lines (+12, -12)
@@ -4,17 +4,17 @@ from dataclasses import dataclass, field
 @dataclass
 class DataConfig:
     template: str = """Translate this Chinese text to English:
-{input}
+{}
 ---
 Translation:
-{output}"""
+{}"""
     train_split: float = 0.9
-    max_samples: int | None = None
+    max_samples: int | None = 3000
 
 
 @dataclass
 class WandBConfig:
-    enabled: bool = True
+    enabled: bool = False
     project: str | None = None
     name: str | None = None
     entity: str | None = None
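The template now uses anonymous positional fields, so callers fill it by position instead of by the old {input}/{output} keyword names. A minimal sketch of the new formatting contract, assuming config.py is importable as config and the template is consumed via str.format (the call site is not part of this diff):

    from config import DataConfig

    cfg = DataConfig()
    # First {} takes the Chinese source, second {} the English reference translation.
    prompt = cfg.template.format("你好,世界", "Hello, world")
    # The previous template required keyword formatting instead:
    #   cfg.template.format(input=src_text, output=tgt_text)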
@@ -34,8 +34,8 @@ class TrainingConfig:
     load_in_4bit: bool = False
 
     # LoRA
-    lora_r: int = 256
-    lora_alpha: int = 512
+    lora_r: int = 64
+    lora_alpha: int = 128
     lora_dropout: float = 0
     target_modules: list[str] = field(
         default_factory=lambda: [
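The LoRA rank and alpha drop from 256/512 to 64/128, which shrinks the adapter roughly 4x while keeping the alpha/r scaling factor at 2. A hedged sketch of how these fields would typically be forwarded to a peft LoraConfig; whether this repo builds the adapter exactly this way is an assumption, since the trainer code is not in this diff:

    from peft import LoraConfig
    from config import TrainingConfig

    tc = TrainingConfig()
    peft_config = LoraConfig(
        r=tc.lora_r,                     # 64 after this change
        lora_alpha=tc.lora_alpha,        # 128, so scaling alpha/r stays 2.0
        lora_dropout=tc.lora_dropout,
        target_modules=tc.target_modules,
        task_type="CAUSAL_LM",
    )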
@@ -56,7 +56,7 @@ class TrainingConfig:
     # training args
     per_device_train_batch_size: int = 16
     gradient_accumulation_steps: int = 2
-    warmup_ratio: float = 0.03
+    warmup_ratio: float = 0.05
     max_grad_norm: float = 1.0
     num_train_epochs: float = 1
     learning_rate: float = 5e-4
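warmup_ratio rises from 0.03 to 0.05, i.e. 5% of the total optimizer steps are spent warming up the learning rate. Rough arithmetic under the values visible in this diff (the step count is only illustrative; packing=True will reduce the real number of steps):

    from config import DataConfig, TrainingConfig

    dc, tc = DataConfig(), TrainingConfig()
    n_train = int(dc.max_samples * dc.train_split)     # 3000 * 0.9 = 2700 samples
    eff_batch = tc.per_device_train_batch_size * tc.gradient_accumulation_steps   # 16 * 2 = 32 per device
    total_steps = int(n_train / eff_batch * tc.num_train_epochs)                  # ~84 before packing
    warmup_steps = int(total_steps * tc.warmup_ratio)                             # ~4 with the new 0.05 ratio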
@@ -74,11 +74,11 @@ class TrainingConfig:
     packing: bool = True
 
     # eval
-    fp16_full_eval = True
-    per_device_eval_batch_size = 2
-    eval_accumulation_steps = 4
-    eval_strategy = "steps"
-    eval_steps = 1
+    fp16_full_eval: bool = True
+    per_device_eval_batch_size: int = 16
+    eval_accumulation_steps: int = 4
+    eval_strategy: str = "steps"
+    eval_steps: int = 5
 
     # output
     output_dir: str = "/workspace/output/"
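Beyond the new values (eval batch size 2 -> 16, eval every 5 steps instead of every step), the added type annotations matter on their own: inside a @dataclass, an assignment without an annotation is just a class attribute, so the old eval settings were not dataclass fields and would be skipped by fields()/asdict() or any code that iterates the config. A standalone demonstration of that Python behavior (not using the repo's classes):

    from dataclasses import dataclass, fields

    @dataclass
    class Before:
        eval_steps = 1                  # no annotation: plain class attribute, not a field

    @dataclass
    class After:
        eval_steps: int = 5             # annotated: a real dataclass field

    print([f.name for f in fields(Before)])   # []
    print([f.name for f in fields(After)])    # ['eval_steps']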