diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..61e9ef4
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,53 @@
+import argparse
+from unsloth import FastLanguageModel
+from transformers import TextStreamer
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run inference with the model.")
+    parser.add_argument(
+        "-i", "--input", type=str, required=True, help="Chinese text to translate."
+    )
+    parser.add_argument(
+        "-s",
+        "--src",
+        type=str,
+        default="/workspace/output",
+        help="Path to the checkpoint directory.",
+    )
+    args = parser.parse_args()
+
+    max_seq_length = 6144
+    dtype = None
+    load_in_4bit = False
+
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=args.src,
+        max_seq_length=max_seq_length,
+        dtype=dtype,
+        load_in_4bit=load_in_4bit,
+    )
+
+    FastLanguageModel.for_inference(model)
+
+    template = """Translate this Chinese text to English:
+{}
+---
+Translation:
+{}"""
+
+    inputs = tokenizer(
+        [template.format(args.input, "")],
+        return_tensors="pt",
+    ).to("cuda")
+
+    text_streamer = TextStreamer(tokenizer)
+
+    print("\nGenerating translation...")
+    _ = model.generate(
+        **inputs, streamer=text_streamer, max_new_tokens=4096, use_cache=True
+    )
+
+
+if __name__ == "__main__":
+    main()
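
Usage note: a minimal invocation sketch for the new script, assuming a CUDA-capable GPU (the inputs are moved to "cuda" unconditionally) and an Unsloth-compatible checkpoint at the default path; the sample sentence and alternate checkpoint path are illustrative, not from this patch:

    python inference.py --input "请把这句话翻译成英文。"
    python inference.py --input "请把这句话翻译成英文。" --src /workspace/output

The prompt fed to the model is the template with the Chinese text in the first slot and an empty string in the second, so generation continues from the "Translation:" marker and the translation is streamed token by token to stdout via TextStreamer rather than returned as a string.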