Commit 655a134

Add top_p and top_k params to chat
1 parent cd35d16 commit 655a134

File tree

1 file changed: +8 -2 lines


utils/chat.py

Lines changed: 8 additions & 2 deletions
@@ -26,6 +26,8 @@ def __init__(
         eos: str,
         n_generate: int,
         temperature: float = 0.8,
+        top_p: float = 0.9,
+        top_k: int = 40,
         agent_prefixes: dict[str, str] = { # Default template (most common one)
             SYSTEM_KEY: '<|im_start|>system',
             ASSISTANT_KEY: '<|im_start|>assistant',
@@ -46,6 +48,8 @@ def __init__(
         @param eos: the token that ends a single chat round
         @param n_generate: the maximum number of tokens generated by the model in a single turn
         @param temperature: the temperature used for model inference
+        @param top_p: the top_p used for model inference
+        @param top_k: the top_k used for model inference
         @param agent_prefixes: the tokens used to wrap an agent name
         @param agent_names: the dict with the names for: system, assistant, user
         @param debug: whether or not to output debug informations
@@ -55,6 +59,8 @@ def __init__(
         self.eos = eos
         self.n_generate = n_generate
         self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
         self.agent_prefixes = agent_prefixes
         self.agent_names = agent_names
         self.debug = debug
@@ -78,7 +84,7 @@ def generate_assistant_reply(self, grammar: LlamaGrammar | None = None) -> tuple
 
         reply = ''
         n_reply_tokens = 0
-        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, grammar=grammar):
+        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, top_p=self.top_p, top_k=self.top_k, grammar=grammar):
             self.check_context_overflow() # Check for context exceeded
             if token == self.model.token_eos() or token == self.eos_token: # Check for EOS termination
                 self.tokens_cache.append(self.eos_token)
@@ -114,7 +120,7 @@ def generate_assistant_reply_stepped(self, grammar: LlamaGrammar | None = None):
 
         reply = ''
         n_reply_tokens = 0
-        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, grammar=grammar):
+        for token in self.model.generate(tokens=self.tokens_cache, temp=self.temperature, top_p=self.top_p, top_k=self.top_k, grammar=grammar):
             self.check_context_overflow()
             if token == self.model.token_eos() or token == self.eos_token: # Check for EOS termination
                 self.tokens_cache.append(self.eos_token)
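
For reference, top_k restricts sampling to the k most probable next tokens and top_p (nucleus sampling) restricts it to the smallest set of tokens whose cumulative probability reaches p; the commit simply stores both values and forwards them to the model's generate() call alongside temp. The sketch below is a self-contained, illustrative version of that filtering over a toy distribution. It is not the library's implementation; the function name and toy values are made up for the example, and only the defaults (top_p=0.9, top_k=40) come from this commit.

    import random

    def filter_top_k_top_p(probs: dict[str, float], top_k: int, top_p: float) -> dict[str, float]:
        """Illustrative top-k / top-p (nucleus) filtering over a toy next-token distribution."""
        # Keep only the top_k most probable tokens.
        ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

        # Then keep the smallest prefix whose cumulative probability reaches top_p.
        kept, cumulative = [], 0.0
        for token, p in ranked:
            kept.append((token, p))
            cumulative += p
            if cumulative >= top_p:
                break

        # Renormalize the surviving probabilities so they sum to 1 before sampling.
        total = sum(p for _, p in kept)
        return {token: p / total for token, p in kept}

    if __name__ == '__main__':
        toy = {'the': 0.50, 'a': 0.30, 'dog': 0.15, 'xyzzy': 0.05}
        filtered = filter_top_k_top_p(toy, top_k=40, top_p=0.9)  # defaults added by this commit
        print(filtered)  # 'xyzzy' is filtered out; the rest is renormalized
        print(random.choices(list(filtered), weights=list(filtered.values()))[0])

Lowering top_p or top_k narrows the candidate set and makes replies more deterministic; raising them (together with temperature) makes them more varied.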
