chore: add model handler for trainer
model_handler.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from unsloth import FastLanguageModel
from transformers import PreTrainedModel, PreTrainedTokenizer

from config import TrainingConfig


class ModelHandler:
    """Loads the base model via Unsloth and attaches LoRA adapters."""

    def __init__(self, config: TrainingConfig):
        self.config = config

    def setup_model(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the base model and tokenizer, then apply the PEFT setup."""
        try:
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.config.base_model,
                max_seq_length=self.config.max_seq_length,
                dtype=self.config.dtype,
                load_in_4bit=self.config.load_in_4bit,
            )

            model = self._setup_peft(model)
            return model, tokenizer
        except Exception as e:
            # Chain the original exception so the underlying traceback survives.
            raise RuntimeError(f"Error setting up model: {e}") from e

    def _setup_peft(self, model: PreTrainedModel) -> PreTrainedModel:
        """Set up the PEFT (LoRA) configuration for the model."""
        try:
            return FastLanguageModel.get_peft_model(
                model,
                r=self.config.lora_r,
                target_modules=self.config.target_modules,
                lora_alpha=self.config.lora_alpha,
                lora_dropout=self.config.lora_dropout,
                bias="none",
                use_gradient_checkpointing=self.config.use_gradient_checkpointing,
                random_state=self.config.random_state,
                max_seq_length=self.config.max_seq_length,
                use_rslora=self.config.use_rslora,
                loftq_config=self.config.loftq_config,
            )
        except Exception as e:
            raise RuntimeError(f"Error setting up PEFT: {e}") from e
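
For context, here is a minimal sketch of the TrainingConfig this handler expects. The field names are taken from the attribute accesses above; the types and default values are illustrative assumptions, not the contents of this repository's config.py.

# Hypothetical TrainingConfig sketch -- field names come from the accesses
# in ModelHandler; every type and default value here is an assumption.
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class TrainingConfig:
    base_model: str = "unsloth/llama-3-8b-bnb-4bit"  # assumed base checkpoint
    max_seq_length: int = 2048
    dtype: Optional[torch.dtype] = None              # None lets Unsloth auto-select
    load_in_4bit: bool = True
    lora_r: int = 16
    target_modules: list[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ])
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    use_gradient_checkpointing: str = "unsloth"      # Unsloth's offloaded variant
    random_state: int = 3407
    use_rslora: bool = False
    loftq_config: Optional[dict] = None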
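
With a config like that, the handler can be driven as below; again a usage sketch under the assumptions above, not code from this commit.

# Hypothetical usage; assumes the TrainingConfig sketch above.
from model_handler import ModelHandler

handler = ModelHandler(TrainingConfig())
model, tokenizer = handler.setup_model()  # base model + tokenizer, LoRA attached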