import torch

from unsloth import FastLanguageModel
from transformers import PreTrainedModel, PreTrainedTokenizer

from config import TrainingConfig


class ModelHandler:
    """Loads the base model via Unsloth and attaches the LoRA/PEFT adapters."""

    def __init__(self, config: TrainingConfig):
        self.config = config

    def setup_model(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the base model and tokenizer, then apply the PEFT configuration."""
        try:
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.config.base_model,
                max_seq_length=self.config.max_seq_length,
                dtype=self.config.dtype,
                load_in_4bit=self.config.load_in_4bit,
            )

            model = self._setup_peft(model)
            return model, tokenizer
        except Exception as e:
            # Re-raise with context while preserving the original traceback.
            raise RuntimeError(f"Error setting up model: {e}") from e

    def _setup_peft(self, model: PreTrainedModel) -> PreTrainedModel:
        """Set up the PEFT (LoRA) config for the model."""
        try:
            return FastLanguageModel.get_peft_model(
                model,
                r=self.config.lora_r,
                target_modules=self.config.target_modules,
                lora_alpha=self.config.lora_alpha,
                lora_dropout=self.config.lora_dropout,
                bias="none",
                use_gradient_checkpointing=self.config.use_gradient_checkpointing,
                random_state=self.config.random_state,
                max_seq_length=self.config.max_seq_length,
                use_rslora=self.config.use_rslora,
                loftq_config=self.config.loftq_config,
            )
        except Exception as e:
            # Re-raise with context while preserving the original traceback.
            raise RuntimeError(f"Error setting up PEFT: {e}") from e
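

# Minimal usage sketch. This assumes TrainingConfig() can be constructed with
# defaults for the fields referenced above (base_model, max_seq_length, dtype,
# load_in_4bit, and the LoRA settings); adjust to the actual config definition.
if __name__ == "__main__":
    config = TrainingConfig()
    handler = ModelHandler(config)
    model, tokenizer = handler.setup_model()
    print(f"Loaded {config.base_model} with LoRA adapters attached.")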