import torch
from unsloth import FastLanguageModel
from transformers import PreTrainedModel, PreTrainedTokenizer

from config import TrainingConfig


class ModelHandler:
    """Load the base model described by a ``TrainingConfig`` and attach PEFT.

    The handler only stores the configuration; nothing is loaded until
    ``setup_model`` is called.
    """

    def __init__(self, config: TrainingConfig):
        """Store ``config``; its fields are read later by the setup methods."""
        self.config = config

    def setup_model(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Load the base model and tokenizer, then wrap the model with PEFT.

        Reads ``base_model``, ``max_seq_length``, ``dtype`` and
        ``load_in_4bit`` from the stored config and forwards them to
        ``FastLanguageModel.from_pretrained``.

        Returns:
            A ``(model, tokenizer)`` pair, where ``model`` is the result of
            ``_setup_peft`` applied to the loaded model.

        Raises:
            Exception: if loading fails ("Error setting up model: ...") or
                if the PEFT step fails ("Error setting up PEFT: ..."). The
                original error is attached as ``__cause__``.
        """
        # Keep the try body limited to the load call so that a PEFT failure
        # is reported once by _setup_peft instead of being wrapped a second
        # time as a model-loading error.
        try:
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.config.base_model,
                max_seq_length=self.config.max_seq_length,
                dtype=self.config.dtype,
                load_in_4bit=self.config.load_in_4bit,
            )
        except Exception as e:
            raise Exception(f"Error setting up model: {e}") from e

        model = self._setup_peft(model)
        return model, tokenizer

    def _setup_peft(self, model: PreTrainedModel) -> PreTrainedModel:
        """Setup PEFT config for the model.

        Forwards ``model`` and the adapter-related config fields
        (``lora_r``, ``target_modules``, ``lora_alpha``, ``lora_dropout``,
        ``use_gradient_checkpointing``, ``random_state``, ``max_seq_length``,
        ``use_rslora``, ``loftq_config``) to
        ``FastLanguageModel.get_peft_model``; ``bias`` is fixed to ``"none"``.

        Returns:
            Whatever ``FastLanguageModel.get_peft_model`` returns.

        Raises:
            Exception: "Error setting up PEFT: ..." with the original error
                attached as ``__cause__``.
        """
        try:
            return FastLanguageModel.get_peft_model(
                model,
                r=self.config.lora_r,
                target_modules=self.config.target_modules,
                lora_alpha=self.config.lora_alpha,
                lora_dropout=self.config.lora_dropout,
                bias="none",
                use_gradient_checkpointing=self.config.use_gradient_checkpointing,
                random_state=self.config.random_state,
                max_seq_length=self.config.max_seq_length,
                use_rslora=self.config.use_rslora,
                loftq_config=self.config.loftq_config,
            )
        except Exception as e:
            raise Exception(f"Error setting up PEFT: {e}") from e