@@ -26,6 +26,8 @@ def __init__(
         eos: str,
         n_generate: int,
         temperature: float = 0.8,
+        top_p: float = 0.9,
+        top_k: int = 40,
         agent_prefixes: dict[str, str] = {  # Default template (most common one)
             SYSTEM_KEY: '<|im_start|>system',
             ASSISTANT_KEY: '<|im_start|>assistant',
@@ -46,6 +48,8 @@ def __init__(
         @param eos: the token that ends a single chat round
         @param n_generate: the maximum number of tokens generated by the model in a single turn
         @param temperature: the temperature used for model inference
+        @param top_p: the top-p (nucleus sampling) value used for model inference
+        @param top_k: the top-k sampling value used for model inference
         @param agent_prefixes: the tokens used to wrap an agent name
         @param agent_names: the dict with the names for: system, assistant, user
         @param debug: whether or not to output debug information
@@ -55,6 +59,8 @@ def __init__(
         self.eos = eos
         self.n_generate = n_generate
         self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
         self.agent_prefixes = agent_prefixes
         self.agent_names = agent_names
         self.debug = debug
@@ -78,7 +84,7 @@ def generate_assistant_reply(self, grammar: LlamaGrammar | None = None) -> tuple

         reply = ''
         n_reply_tokens = 0
-        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, grammar=grammar):
+        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, top_p=self.top_p, top_k=self.top_k, grammar=grammar):
             self.check_context_overflow()  # Check for context exceeded
             if token == self.model.token_eos() or token == self.eos_token:  # Check for EOS termination
                 self.tokens_cache.append(self.eos_token)
@@ -114,7 +120,7 @@ def generate_assistant_reply_stepped(self, grammar: LlamaGrammar | None = None):

         reply = ''
         n_reply_tokens = 0
-        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, grammar=grammar):
+        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, top_p=self.top_p, top_k=self.top_k, grammar=grammar):
             self.check_context_overflow()
             if token == self.model.token_eos() or token == self.eos_token:  # Check for EOS termination
                 self.tokens_cache.append(self.eos_token)
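
For reference, the top_p / top_k keywords this patch threads through map directly onto llama_cpp.Llama.generate, which accepts them alongside temp and grammar. Below is a minimal standalone sketch of the same sampling loop against llama-cpp-python; the model path, prompt, and the 512-token cap are placeholder assumptions for illustration, not part of this commit.

# Standalone sketch of the sampling loop the patch modifies.
# 'model.gguf', the prompt, and the token cap are assumptions.
from llama_cpp import Llama

llm = Llama(model_path='model.gguf')
tokens = llm.tokenize(b'<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n')

reply = b''
for n, token in enumerate(llm.generate(tokens=tokens, temp=0.8, top_p=0.9, top_k=40, grammar=None)):
    if token == llm.token_eos() or n >= 512:  # stop at EOS or the cap, as the wrapper does with n_generate
        break
    reply += llm.detokenize([token])
print(reply.decode('utf-8', errors='ignore'))

With top_p=1.0 and a sufficiently large top_k the call reduces to plain temperature sampling, so the new defaults (0.9 / 40) only ever tighten the candidate token set.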