chore: add inference
This commit is contained in:
53
inference.py
Normal file
53
inference.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import argparse
|
||||
from unsloth import FastLanguageModel
|
||||
from transformers import TextStreamer
|
||||
|
||||
|
||||
def main(argv=None):
    """Translate one piece of Chinese text to English with a saved checkpoint.

    Parses the command line, loads the checkpoint with
    ``FastLanguageModel.from_pretrained``, fills the translation prompt with
    the user's text, and streams the generated tokens through a
    ``TextStreamer`` while ``model.generate`` runs.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` (the default) lets argparse read ``sys.argv``,
            so a bare ``main()`` call behaves as before.

    Raises:
        SystemExit: Raised by argparse with status 2 when ``--input`` is
            missing or contains only whitespace.
    """
    parser = argparse.ArgumentParser(description="Run inference with the model.")
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Chinese text to translate."
    )
    parser.add_argument(
        "-s",
        "--src",
        type=str,
        default="/workspace/output",
        help="Path to the checkpoint directory.",
    )
    args = parser.parse_args(argv)

    # Fail fast, before the expensive checkpoint load: a blank input leaves
    # the prompt with nothing to translate.
    if not args.input.strip():
        parser.error("--input must not be empty.")

    max_seq_length = 6144
    dtype = None
    load_in_4bit = False

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.src,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )

    FastLanguageModel.for_inference(model)

    # Built from explicit "\n"-terminated pieces so the prompt text does not
    # depend on how this function body is indented.
    template = (
        "Translate this Chinese text to English:\n"
        "{}\n"
        "---\n"
        "Translation:\n"
        "{}"
    )

    # The second slot is left empty: generation continues from "Translation:".
    prompt = template.format(args.input, "")

    # Place the tensors on whatever device the model was loaded onto instead
    # of assuming the literal "cuda" device.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    text_streamer = TextStreamer(tokenizer)

    print("\nGenerating translation...")
    # NOTE(review): max_new_tokens (4096) plus a long prompt can exceed
    # max_seq_length (6144) - confirm this is acceptable for expected inputs.
    _ = model.generate(
        **inputs, streamer=text_streamer, max_new_tokens=4096, use_cache=True
    )
|
||||
|
||||
|
||||
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user