import argparse

from unsloth import FastLanguageModel
from transformers import TextStreamer


def main():
    """Load a fine-tuned checkpoint and stream a Chinese->English translation.

    Command-line interface:
        -i / --input            Chinese text to translate (required).
        -s / --src              Checkpoint directory (default: /workspace/output).
        --max-new-tokens        Generation budget (default: 4096, the previous
                                hard-coded value, so existing invocations are
                                unchanged).
    """
    parser = argparse.ArgumentParser(description="Run inference with the model.")
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Chinese text to translate."
    )
    parser.add_argument(
        "-s",
        "--src",
        type=str,
        default="/workspace/output",
        help="Path to the checkpoint directory.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=4096,
        help="Maximum number of tokens to generate.",
    )
    args = parser.parse_args()

    max_seq_length = 6144
    dtype = None  # auto-detect (unsloth picks bf16/fp16 based on hardware)
    load_in_4bit = False

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.src,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )
    # Switch unsloth's patched model into its fast-inference mode.
    FastLanguageModel.for_inference(model)

    # Prompt format the checkpoint was trained on; the second slot is left
    # empty so the model completes the translation.
    template = """Translate this Chinese text to English:
{}
---
Translation:
{}"""

    # NOTE(review): hard-coded "cuda" will raise on CPU-only hosts — unsloth
    # generally requires a CUDA GPU, but confirm before running elsewhere.
    inputs = tokenizer(
        [template.format(args.input, "")],
        return_tensors="pt",
    ).to("cuda")

    # Stream decoded tokens to stdout as they are generated.
    text_streamer = TextStreamer(tokenizer)
    print("\nGenerating translation...")
    _ = model.generate(
        **inputs,
        streamer=text_streamer,
        max_new_tokens=args.max_new_tokens,
        use_cache=True,
    )


if __name__ == "__main__":
    main()