@@ -11,16 +11,10 @@ def parse_inputs():
         default="01-ai/Yi-6B",
         help="pretrained model path locally or name on huggingface",
     )
-    parser.add_argument(
-        "--tokenizer",
-        type=str,
-        default="",
-        help="tokenizer path locally or name on huggingface",
-    )
     parser.add_argument(
         "--max-tokens",
         type=int,
-        default=512,
+        default=256,
         help="max number of tokens to generate",
     )
     parser.add_argument(
@@ -34,41 +28,41 @@ def parse_inputs():
         default="Let me tell you an interesting story about cat Tom and mouse Jerry,",
         help="The prompt to start with",
     )
-    parser.add_argument(
-        "--eos-token",
-        type=str,
-        default="<|endoftext|>",
-        help="End of sentence token",
-    )
+    parser.add_argument("--cpu", action="store_true", help="Run demo with CPU only")
     args = parser.parse_args()
     return args


 def main(args):
     print(args)
+
+    if args.cpu:
+        device_map = "cpu"
+    else:
+        device_map = "auto"
+
     model = AutoModelForCausalLM.from_pretrained(
-        args.model, device_map="auto", torch_dtype="auto", trust_remote_code=True
-    )
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer or args.model, trust_remote_code=True
+        args.model, device_map=device_map, torch_dtype="auto"
     )
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
     inputs = tokenizer(
         args.prompt,
         return_tensors="pt",
-    )
+    ).to(model.device)
+
     streamer = TextStreamer(tokenizer) if args.streaming else None
     outputs = model.generate(
-        inputs.input_ids.cuda(),
+        **inputs,
         max_new_tokens=args.max_tokens,
         streamer=streamer,
-        eos_token_id=tokenizer.convert_tokens_to_ids(args.eos_token),
-        do_sample=True,
-        repetition_penalty=1.3,
-        no_repeat_ngram_size=5,
-        temperature=0.7,
-        top_k=40,
-        top_p=0.8,
+        # do_sample=True,
+        # repetition_penalty=1.3,
+        # no_repeat_ngram_size=5,
+        # temperature=0.7,
+        # top_k=40,
+        # top_p=0.8,
     )
+
     if streamer is None:
         print(tokenizer.decode(outputs[0], skip_special_tokens=True))
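
For context, a minimal standalone sketch of the device-agnostic pattern this diff adopts: the hard-coded inputs.input_ids.cuda() call is replaced by moving the encoded batch to model.device, so the demo also runs with --cpu. This assumes the standard Hugging Face transformers API; the prompt below is illustrative and not taken from the script.

from transformers import AutoModelForCausalLM, AutoTokenizer

# device_map="cpu" here would correspond to passing --cpu to the demo script.
model = AutoModelForCausalLM.from_pretrained(
    "01-ai/Yi-6B", device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-6B")

# Moving the encoded inputs to model.device (instead of calling .cuda())
# keeps generation working whether the model was placed on GPU or CPU.
inputs = tokenizer("Hello,", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))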