chore: _
train.py (26 changed lines)

--- a/train.py
+++ b/train.py
@@ -1,5 +1,5 @@
 import argparse
+import os
 import torch
 from unsloth import FastLanguageModel
 from datasets import load_dataset
@@ -13,7 +13,8 @@ def load_data(path):
         import gdown
 
         local_path = "downloaded_dataset.json"
-        gdown.download(url=path, output=local_path, fuzzy=True)
+        if not os.path.exists(local_path):
+            gdown.download(url=path, output=local_path, fuzzy=True)
         dataset_path = local_path
     except ImportError:
         raise ImportError("Please install gdown: pip install gdown")
@@ -23,7 +24,6 @@ def load_data(path):
     dataset = load_dataset("json", data_files=dataset_path, split="train")
     return dataset
 
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
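For context, a minimal sketch of how the patched download helper plausibly reads after these hunks. Only the changed and context lines shown above come from the commit; the rest of the function body (including how plain local paths are handled) is not visible in the diff, so the remaining lines are assumptions.

import os
from datasets import load_dataset


def load_data(path):
    # Assumed overall shape: download a Google Drive file once, then reuse the cached copy.
    dataset_path = path
    try:
        import gdown

        local_path = "downloaded_dataset.json"
        if not os.path.exists(local_path):  # guard added by this commit: skip re-downloading
            gdown.download(url=path, output=local_path, fuzzy=True)
        dataset_path = local_path
    except ImportError:
        raise ImportError("Please install gdown: pip install gdown")

    dataset = load_dataset("json", data_files=dataset_path, split="train")
    return dataset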
@@ -45,7 +45,7 @@ def main():
 
     max_seq_length = 16384 # Choose any! We auto support RoPE Scaling internally!
     dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
-    load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False.
+    load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
 
     model, tokenizer = FastLanguageModel.from_pretrained(
         model_name=args.base_model,
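Flipping load_in_4bit to True makes Unsloth load the base weights 4-bit quantized via bitsandbytes, cutting weight memory roughly 4x versus fp16/bf16, which is presumably what makes room for the larger per-device batch further down. A hedged sketch of how these flags are typically passed (the model name below is a placeholder; the script itself uses args.base_model):

from unsloth import FastLanguageModel

# Sketch only; model_name is a placeholder, the actual script passes args.base_model.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",
    max_seq_length=16384,  # matches max_seq_length above
    dtype=None,            # auto-detect: bfloat16 on Ampere+, float16 on T4/V100
    load_in_4bit=True,     # changed from False in this commit
)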
@@ -86,7 +86,7 @@ def main():
 {}"""
 
     DATASET_PATH = args.dataset
-    dataset = load_dataset("json", data_files=DATASET_PATH, split="train")
+    dataset = load_data(DATASET_PATH)
 
     EOS_TOKEN = tokenizer.eos_token
     print(f"EOS Token: {EOS_TOKEN}")
@@ -112,20 +112,23 @@ def main():
         train_dataset=dataset,
         dataset_text_field="text",
         max_seq_length=max_seq_length,
+        dataset_num_proc = 2,
         packing=False,
         args=TrainingArguments(
-            per_device_train_batch_size=16,
-            gradient_accumulation_steps=2,
+            per_device_train_batch_size=32,
+            gradient_accumulation_steps=1,
+            # warmup_steps=10,
+            # max_steps=int(31583 * 0.5 / 40),
             warmup_ratio=0.05,
             max_grad_norm=1.0,
-            num_train_epochs=1,
-            learning_rate=1e-4,
+            num_train_epochs=0.5,
+            learning_rate=3e-5,
             fp16=not torch.cuda.is_bf16_supported(),
             bf16=torch.cuda.is_bf16_supported(),
-            logging_steps=50,
+            logging_steps=5,
             optim="adamw_8bit",
             weight_decay=0.05,
-            lr_scheduler_type="cosine",
+            lr_scheduler_type="linear",
             seed=3407,
             output_dir="/output/",
             report_to=None,
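A quick sanity check on the schedule changes: the effective per-device batch is unchanged (16 x 2 accumulation steps = 32 before, 32 x 1 = 32 after), while the run gets shorter and more conservative (half an epoch, learning rate 3e-5 instead of 1e-4, linear instead of cosine decay). If the 31583 in the commented-out max_steps formula is the number of training examples and 40 an intended effective batch size, both assumptions rather than facts from the diff, the arithmetic works out as follows:

# Sanity-check sketch; 31583 and 40 come from the commented-out line above,
# and their interpretation (dataset size, effective batch size) is an assumption.
per_device_train_batch_size = 32
gradient_accumulation_steps = 1
effective_batch = per_device_train_batch_size * gradient_accumulation_steps
assert effective_batch == 16 * 2  # same effective batch as before this commit

num_examples = 31583
num_train_epochs = 0.5
max_steps = int(num_examples * num_train_epochs / 40)
print(effective_batch, max_steps)  # -> 32 394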
@@ -173,3 +176,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+