54 lines
1.2 KiB
Python
54 lines
1.2 KiB
Python
import argparse
|
|
from unsloth import FastLanguageModel
|
|
from transformers import TextStreamer
|
|
|
|
|
|
def main():
    """Translate a Chinese string to English with a fine-tuned checkpoint.

    Parses the command line for the input text and checkpoint path, loads
    the model through unsloth in full precision, and streams the generated
    translation to stdout as tokens are produced.
    """
    parser = argparse.ArgumentParser(description="Run inference with the model.")
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Chinese text to translate."
    )
    parser.add_argument(
        "-s",
        "--src",
        type=str,
        default="/workspace/output",
        help="Path to the checkpoint directory.",
    )
    args = parser.parse_args()

    # Loader configuration: 6144-token context, dtype auto-detected (None),
    # full-precision weights (4-bit quantization disabled).
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.src,
        max_seq_length=6144,
        dtype=None,
        load_in_4bit=False,
    )

    # Switch unsloth's wrappers into their faster inference-only mode.
    FastLanguageModel.for_inference(model)

    prompt_template = """Translate this Chinese text to English:
{}
---
Translation:
{}"""

    # Second slot is left empty — the model completes the translation.
    prompt = prompt_template.format(args.input, "")
    # NOTE(review): device is hard-coded to CUDA; unsloth assumes a GPU.
    batch = tokenizer([prompt], return_tensors="pt").to("cuda")

    streamer = TextStreamer(tokenizer)

    print("\nGenerating translation...")
    _ = model.generate(
        **batch, streamer=streamer, max_new_tokens=4096, use_cache=True
    )
# Script entry point: only run inference when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()