unsloth-train-scripts/inference.py
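"""Run streaming Chinese-to-English translation with a fine-tuned Unsloth checkpoint.

Usage sketch (the checkpoint path is whatever -s points at; it defaults to
/workspace/output as configured in main() below):

    python inference.py -i "<Chinese text>" [-s /path/to/checkpoint]
"""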
import argparse

from unsloth import FastLanguageModel
from transformers import TextStreamer


def main():
    parser = argparse.ArgumentParser(description="Run inference with the model.")
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Chinese text to translate."
    )
    parser.add_argument(
        "-s",
        "--src",
        type=str,
        default="/workspace/output",
        help="Path to the checkpoint directory.",
    )
    args = parser.parse_args()
    # Load the checkpoint with Unsloth. dtype=None lets Unsloth pick a suitable
    # dtype for the GPU; 4-bit quantized loading is disabled here.
    max_seq_length = 6144
    dtype = None
    load_in_4bit = False
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.src,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )
    FastLanguageModel.for_inference(model)  # switch to Unsloth's fast inference mode
    # Prompt template; the second placeholder is left empty so the model
    # completes the translation after "Translation:".
    template = """Translate this Chinese text to English:
{}
---
Translation:
{}"""
    inputs = tokenizer(
        [template.format(args.input, "")],
        return_tensors="pt",
    ).to("cuda")

    # Stream generated tokens to stdout as they are produced.
    text_streamer = TextStreamer(tokenizer)
    print("\nGenerating translation...")
    _ = model.generate(
        **inputs, streamer=text_streamer, max_new_tokens=4096, use_cache=True
    )


if __name__ == "__main__":
    main()