Dynamic Context Window Size for Ollama Chat #6582

Merged
3 commits merged on Mar 28, 2025
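This change makes OllamaChat size the Ollama context window per request: a new Base._calculate_dynamic_ctx helper estimates the token count of the chat history (roughly one token per ASCII character, two per non-ASCII character, plus a small per-message overhead), applies a 1.2x buffer, and rounds up to a multiple of 8192, replacing the previous fixed num_ctx of 32768. The diff also fixes the Authorization header typo ("Bear" to "Bearer") and changes keep_alive from -1 to 10.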
rag/llm/chat_model.py (101 changes: 76 additions & 25 deletions)
@@ -179,7 +179,41 @@ def total_token_count(self, resp):
        except Exception:
            pass
        return 0


    def _calculate_dynamic_ctx(self, history):
        """Calculate dynamic context window size"""
        def count_tokens(text):
            """Calculate token count for text"""
            # Simple calculation: 1 token per ASCII character
            # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
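            # Illustrative example (hypothetical input): count_tokens("hello 你好")
            # counts 6 ASCII characters at 1 token each plus 2 CJK characters at
            # 2 tokens each, giving an estimate of 10 tokens.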
            total = 0
            for char in text:
                if ord(char) < 128:  # ASCII characters
                    total += 1
                else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
                    total += 2
            return total

        # Calculate total tokens for all messages
        total_tokens = 0
        for message in history:
            content = message.get("content", "")
            # Calculate content tokens
            content_tokens = count_tokens(content)
            # Add role marker token overhead
            role_tokens = 4
            total_tokens += content_tokens + role_tokens

        # Apply 1.2x buffer ratio
        total_tokens_with_buffer = int(total_tokens * 1.2)

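        # Round the buffered estimate up to the next multiple of 8192, with a floor
        # of 8192. Worked example (hypothetical numbers): 20,000 estimated tokens
        # * 1.2 = 24,000, and 24,000 // 8192 + 1 = 3, so num_ctx = 3 * 8192 = 24576.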
        if total_tokens_with_buffer <= 8192:
            ctx_size = 8192
        else:
            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
            ctx_size = ctx_multiplier * 8192

        return ctx_size

class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -469,7 +503,7 @@ def chat_streamly(self, system, history, gen_conf):

class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
@@ -478,7 +512,12 @@ def chat(self, system, history, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
try:
options = {"num_ctx": 32768}
# Calculate context size
ctx_size = self._calculate_dynamic_ctx(history)

options = {
"num_ctx": ctx_size
}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
@@ -489,9 +528,11 @@ def chat(self, system, history, gen_conf):
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)

response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=10)
ans = response["message"]["content"].strip()
return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
return ans, token_count
except Exception as e:
return "**ERROR**: " + str(e), 0

@@ -500,28 +541,38 @@ def chat_streamly(self, system, history, gen_conf):
            history.insert(0, {"role": "system", "content": system})
        if "max_tokens" in gen_conf:
            del gen_conf["max_tokens"]
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans = resp["message"]["content"]
                yield ans
        try:
            # Calculate context size
            ctx_size = self._calculate_dynamic_ctx(history)
            options = {
                "num_ctx": ctx_size
            }
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]

            ans = ""
            try:
                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
                for resp in response:
                    if resp["done"]:
                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                        yield token_count
                    ans = resp["message"]["content"]
                    yield ans
            except Exception as e:
                yield ans + "\n**ERROR**: " + str(e)
                yield 0
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0
            yield "**ERROR**: " + str(e)
            yield 0
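    # Streaming sketch (illustrative, building on the OllamaChat instance above):
    # iterate the generator; string chunks are partial answers and the final
    # integer yielded is the combined token count.
    #   for chunk in m.chat_streamly("You are a helpful assistant.", history, {"temperature": 0.1}):
    #       print(chunk)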


class LocalAIChat(Base):