Files
unsloth-train-scripts/model_handler.py
2025-02-13 21:42:03 +06:00

43 lines
1.6 KiB
Python

import torch
from unsloth import FastLanguageModel
from transformers import PreTrainedModel, PreTrainedTokenizer
from config import TrainingConfig
class ModelHandler:
    """Build a PEFT-wrapped model and its tokenizer from a training config.

    The handler only stores the configuration on construction; nothing is
    loaded until :meth:`setup_model` is called.
    """

    def __init__(self, config: TrainingConfig):
        """Store the configuration read later by the setup methods.

        Args:
            config: Settings object whose attributes (``base_model``,
                ``max_seq_length``, ``dtype``, ``load_in_4bit``, the
                ``lora_*`` values, ...) are read when the model is built.
        """
        self.config = config

    def setup_model(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the base model and tokenizer, then apply the PEFT setup.

        Returns:
            A ``(model, tokenizer)`` pair, where ``model`` is the result of
            passing the loaded model through :meth:`_setup_peft`.

        Raises:
            RuntimeError: If reading the config, loading the model, or the
                PEFT setup fails. The message is prefixed with
                ``"Error setting up model: "`` and the original exception
                is chained as ``__cause__``.
        """
        try:
            # Read every setting up front so a missing or misnamed config
            # attribute is reported before any loading is attempted.
            cfg = self.config
            load_kwargs = {
                "model_name": cfg.base_model,
                "max_seq_length": cfg.max_seq_length,
                "dtype": cfg.dtype,
                "load_in_4bit": cfg.load_in_4bit,
            }
            model, tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)
            # A failure inside _setup_peft arrives here already wrapped, so
            # the final message reads
            # "Error setting up model: Error setting up PEFT: ...".
            return self._setup_peft(model), tokenizer
        except Exception as e:
            # Chain explicitly so the root cause stays attached to the
            # wrapper instead of only being flattened into the message.
            raise RuntimeError(f"Error setting up model: {e}") from e

    def _setup_peft(self, model: PreTrainedModel) -> PreTrainedModel:
        """Setup PEFT config for the model.

        Args:
            model: The model to hand to ``FastLanguageModel.get_peft_model``.

        Returns:
            Whatever ``FastLanguageModel.get_peft_model`` returns for
            ``model`` with the settings taken from ``self.config``.

        Raises:
            RuntimeError: If reading the config or the PEFT call fails. The
                message is prefixed with ``"Error setting up PEFT: "`` and
                the original exception is chained as ``__cause__``.
        """
        try:
            cfg = self.config
            peft_kwargs = {
                "r": cfg.lora_r,
                "target_modules": cfg.target_modules,
                "lora_alpha": cfg.lora_alpha,
                "lora_dropout": cfg.lora_dropout,
                # Fixed value; not read from the config.
                "bias": "none",
                "use_gradient_checkpointing": cfg.use_gradient_checkpointing,
                "random_state": cfg.random_state,
                "max_seq_length": cfg.max_seq_length,
                "use_rslora": cfg.use_rslora,
                "loftq_config": cfg.loftq_config,
            }
            return FastLanguageModel.get_peft_model(model, **peft_kwargs)
        except Exception as e:
            raise RuntimeError(f"Error setting up PEFT: {e}") from e