chore: add inference
This commit is contained in:
53
inference.py
Normal file
53
inference.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import argparse
|
||||||
|
from unsloth import FastLanguageModel
|
||||||
|
from transformers import TextStreamer
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Run inference with the model.")
|
||||||
|
parser.add_argument(
|
||||||
|
"-i", "--input", type=str, required=True, help="Chinese text to translate."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
"--src",
|
||||||
|
type=str,
|
||||||
|
default="/workspace/output",
|
||||||
|
help="Path to the checkpoint directory.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
max_seq_length = 6144
|
||||||
|
dtype = None
|
||||||
|
load_in_4bit = False
|
||||||
|
|
||||||
|
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||||
|
model_name=args.src,
|
||||||
|
max_seq_length=max_seq_length,
|
||||||
|
dtype=dtype,
|
||||||
|
load_in_4bit=load_in_4bit,
|
||||||
|
)
|
||||||
|
|
||||||
|
FastLanguageModel.for_inference(model)
|
||||||
|
|
||||||
|
template = """Translate this Chinese text to English:
|
||||||
|
{}
|
||||||
|
---
|
||||||
|
Translation:
|
||||||
|
{}"""
|
||||||
|
|
||||||
|
inputs = tokenizer(
|
||||||
|
[template.format(args.input, "")],
|
||||||
|
return_tensors="pt",
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
text_streamer = TextStreamer(tokenizer)
|
||||||
|
|
||||||
|
print("\nGenerating translation...")
|
||||||
|
_ = model.generate(
|
||||||
|
**inputs, streamer=text_streamer, max_new_tokens=4096, use_cache=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user