Skip to content

Commit 36f9b81

Browse files
committed
update text_generation
1 parent 5ed73fe commit 36f9b81

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

demo/text_generation.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,10 @@ def parse_inputs():
1111
default="01-ai/Yi-6B",
1212
help="pretrained model path locally or name on huggingface",
1313
)
14-
parser.add_argument(
15-
"--tokenizer",
16-
type=str,
17-
default="",
18-
help="tokenizer path locally or name on huggingface",
19-
)
2014
parser.add_argument(
2115
"--max-tokens",
2216
type=int,
23-
default=512,
17+
default=256,
2418
help="max number of tokens to generate",
2519
)
2620
parser.add_argument(
@@ -34,41 +28,41 @@ def parse_inputs():
3428
default="Let me tell you an interesting story about cat Tom and mouse Jerry,",
3529
help="The prompt to start with",
3630
)
37-
parser.add_argument(
38-
"--eos-token",
39-
type=str,
40-
default="<|endoftext|>",
41-
help="End of sentence token",
42-
)
31+
parser.add_argument("--cpu", action="store_true", help="Run demo with CPU only")
4332
args = parser.parse_args()
4433
return args
4534

4635

4736
def main(args):
4837
print(args)
38+
39+
if args.cpu:
40+
device_map = "cpu"
41+
else:
42+
device_map = "auto"
43+
4944
model = AutoModelForCausalLM.from_pretrained(
50-
args.model, device_map="auto", torch_dtype="auto", trust_remote_code=True
51-
)
52-
tokenizer = AutoTokenizer.from_pretrained(
53-
args.tokenizer or args.model, trust_remote_code=True
45+
args.model, device_map=device_map, torch_dtype="auto"
5446
)
47+
tokenizer = AutoTokenizer.from_pretrained(args.model)
5548
inputs = tokenizer(
5649
args.prompt,
5750
return_tensors="pt",
58-
)
51+
).to(model.device)
52+
5953
streamer = TextStreamer(tokenizer) if args.streaming else None
6054
outputs = model.generate(
61-
inputs.input_ids.cuda(),
55+
**inputs,
6256
max_new_tokens=args.max_tokens,
6357
streamer=streamer,
64-
eos_token_id=tokenizer.convert_tokens_to_ids(args.eos_token),
65-
do_sample=True,
66-
repetition_penalty=1.3,
67-
no_repeat_ngram_size=5,
68-
temperature=0.7,
69-
top_k=40,
70-
top_p=0.8,
58+
# do_sample=True,
59+
# repetition_penalty=1.3,
60+
# no_repeat_ngram_size=5,
61+
# temperature=0.7,
62+
# top_k=40,
63+
# top_p=0.8,
7164
)
65+
7266
if streamer is None:
7367
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
7468

0 commit comments

Comments
 (0)