Commit

Merge pull request #395 from MannLabs/show_warnings

Show warnings from truncation

JuliaS92 authored Jan 28, 2025
2 parents 57de2c7 + 2c82676 commit d7eba93
Showing 4 changed files with 38 additions and 18 deletions.
30 changes: 18 additions & 12 deletions alphastats/gui/pages/06_LLM.py
@@ -1,4 +1,5 @@
import os
import warnings
from typing import Dict

import pandas as pd
@@ -211,7 +212,7 @@ def llm_config():
st.markdown("##### Prompts generated based on analysis input")
with st.expander("System message", expanded=False):
system_message = st.text_area(
"",
" ",
value=get_system_message(st.session_state[StateKeys.DATASET]),
height=150,
disabled=llm_integration_set_for_model,
@@ -221,7 +222,7 @@
with st.expander("Initial prompt", expanded=True):
feature_to_repr_map = st.session_state[StateKeys.DATASET]._feature_to_repr_map
initial_prompt = st.text_area(
"",
" ",
value=get_initial_prompt(
plot_parameters,
list(
@@ -304,16 +305,12 @@ def llm_chat(
# Alternatively write it all in one pdf report using e.g. pdfrw and reportlab (I have code for that combo).

# no. tokens spent
total_tokens = 0
pinned_tokens = 0
for message in llm_integration.get_print_view(show_all=show_all):
messages, total_tokens, pinned_tokens = llm_integration.get_print_view(
show_all=show_all
)
for message in messages:
with st.chat_message(message[MessageKeys.ROLE]):
st.markdown(message[MessageKeys.CONTENT])
tokens = llm_integration.estimate_tokens([message])
if message[MessageKeys.IN_CONTEXT]:
total_tokens += tokens
if message[MessageKeys.PINNED]:
pinned_tokens += tokens
if (
message[MessageKeys.PINNED]
or not message[MessageKeys.IN_CONTEXT]
@@ -325,6 +322,7 @@
if not message[MessageKeys.IN_CONTEXT]:
token_message += ":x: "
if show_individual_tokens:
tokens = llm_integration.estimate_tokens([message])
token_message += f"*tokens: {str(tokens)}*"
st.markdown(token_message)
for artifact in message[MessageKeys.ARTIFACTS]:
@@ -342,15 +340,24 @@
f"*total tokens used: {str(total_tokens)}, tokens used for pinned messages: {str(pinned_tokens)}*"
)

if st.session_state[StateKeys.RECENT_CHAT_WARNINGS]:
st.warning("Warnings during last chat completion:")
for warning in st.session_state[StateKeys.RECENT_CHAT_WARNINGS]:
st.warning(str(warning.message).replace("\n", "\n\n"))

if prompt := st.chat_input("Say something"):
with st.chat_message(Roles.USER):
st.markdown(prompt)
if show_individual_tokens:
st.markdown(
f"*tokens: {str(llm_integration.estimate_tokens([{MessageKeys.CONTENT:prompt}]))}*"
)
with st.spinner("Processing prompt..."):
with st.spinner("Processing prompt..."), warnings.catch_warnings(
record=True
) as caught_warnings:
llm_integration.chat_completion(prompt)
st.session_state[StateKeys.RECENT_CHAT_WARNINGS] = caught_warnings

st.rerun(scope="fragment")

st.download_button(
Expand Down Expand Up @@ -378,7 +385,6 @@ def llm_chat(
key="show_individual_tokens",
help="Show individual token estimates for each message.",
)

llm_chat(
st.session_state[StateKeys.LLM_INTEGRATION][model_name],
show_all,
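For reference, a minimal standalone sketch of the warning-capture pattern used above: warnings.catch_warnings(record=True) collects any warnings raised inside the block so they can be surfaced after the call, which is how the page ends up storing them in st.session_state[StateKeys.RECENT_CHAT_WARNINGS]. The run_completion function and the truncation warning below are placeholders for illustration, not code from this repository.

    import warnings

    def run_completion(prompt: str) -> str:
        """Placeholder for a call that may warn, e.g. when the conversation is truncated."""
        warnings.warn("Conversation truncated to fit the token limit.")
        return f"response to: {prompt}"

    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")  # record every warning, including duplicates
        answer = run_completion("Summarize the enriched pathways.")

    # The recorded warnings outlive the context manager and can be rendered later.
    for caught in caught_warnings:
        print(str(caught.message).replace("\n", "\n\n"))
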
4 changes: 4 additions & 0 deletions alphastats/gui/utils/ui_helper.py
@@ -145,6 +145,9 @@ def init_session_state() -> None:
if StateKeys.MAX_TOKENS not in st.session_state:
st.session_state[StateKeys.MAX_TOKENS] = 10000

if StateKeys.RECENT_CHAT_WARNINGS not in st.session_state:
st.session_state[StateKeys.RECENT_CHAT_WARNINGS] = []


class StateKeys(metaclass=ConstantsClass):
USER_SESSION_ID = "user_session_id"
@@ -164,6 +167,7 @@ class StateKeys(metaclass=ConstantsClass):
SELECTED_GENES_DOWN = "selected_genes_down"
SELECTED_UNIPROT_FIELDS = "selected_uniprot_fields"
MAX_TOKENS = "max_tokens"
RECENT_CHAT_WARNINGS = "recent_chat_warnings"

ORGANISM = "organism" # TODO this is essentially a constant

18 changes: 14 additions & 4 deletions alphastats/llm/llm_integration.py
@@ -373,18 +373,27 @@ def _chat_completion_create(self) -> ChatCompletion:
logger.info(".. done")
return result

def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
def get_print_view(
self, show_all=False
) -> Tuple[List[Dict[str, Any]], float, float]:
"""Get a structured view of the conversation history for display purposes."""

print_view = []
total_tokens = 0
pinned_tokens = 0
for message_idx, message in enumerate(self._all_messages):
tokens = self.estimate_tokens([message])
in_context = message in self._messages
if in_context:
total_tokens += tokens
if message[MessageKeys.PINNED]:
pinned_tokens += tokens
if not show_all and (
message[MessageKeys.ROLE] in [Roles.TOOL, Roles.SYSTEM]
):
continue
if not show_all and MessageKeys.TOOL_CALLS in message:
continue
in_context = message in self._messages

print_view.append(
{
@@ -395,11 +404,12 @@ def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
MessageKeys.PINNED: message[MessageKeys.PINNED],
}
)
return print_view

return print_view, total_tokens, pinned_tokens

def get_chat_log_txt(self) -> str:
"""Get a chat log in text format for saving. It excludes tool replies, as they are usually also represented in the artifacts."""
messages = self.get_print_view(show_all=True)
messages, _, _ = self.get_print_view(show_all=True)
chatlog = ""
for message in messages:
if message[MessageKeys.ROLE] == Roles.TOOL:
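A short usage sketch for the changed signature: get_print_view now returns the printable messages together with two token counters, the total for all in-context messages and the share held by pinned ones. Here llm stands for an already-configured LLMIntegration instance and MessageKeys is assumed to be in scope as in the page code above; this is illustrative, not additional repository code.

    # Unpack the new three-value return instead of a bare message list.
    messages, total_tokens, pinned_tokens = llm.get_print_view(show_all=False)

    for message in messages:
        prefix = "[pinned] " if message[MessageKeys.PINNED] else ""
        print(f"{prefix}{message[MessageKeys.ROLE]}: {message[MessageKeys.CONTENT][:80]}")

    print(f"total tokens in context: {total_tokens}, pinned: {pinned_tokens}")
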
4 changes: 2 additions & 2 deletions tests/llm/test_llm_integration.py
@@ -481,7 +481,7 @@ def test_handle_function_calls(

def test_get_print_view_default(llm_with_conversation):
"""Test get_print_view with default settings (show_all=False)"""
print_view = llm_with_conversation.get_print_view()
print_view, _, _ = llm_with_conversation.get_print_view()

# Should only include user and assistant messages without tool_calls
assert print_view == [
@@ -518,7 +518,7 @@ def test_get_print_view_default(llm_with_conversation):

def test_get_print_view_show_all(llm_with_conversation):
"""Test get_print_view with default settings (show_all=True)"""
print_view = llm_with_conversation.get_print_view(show_all=True)
print_view, _, _ = llm_with_conversation.get_print_view(show_all=True)

# Should only include user and assistant messages without tool_calls
assert print_view == [
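If further coverage were wanted, a hypothetical test along these lines could check that the new counters are self-consistent; it reuses the existing llm_with_conversation fixture and is a sketch, not part of this commit.

    def test_get_print_view_token_counts(llm_with_conversation):
        """Hypothetical: pinned tokens should never exceed the in-context total."""
        print_view, total_tokens, pinned_tokens = llm_with_conversation.get_print_view(
            show_all=True
        )
        assert 0 <= pinned_tokens <= total_tokens
        assert isinstance(print_view, list)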