
Commit d7eba93

Merge pull request #395 from MannLabs/show_warnings

Show warnings from truncation

2 parents 57de2c7 + 2c82676

File tree

4 files changed: +38 -18 lines changed

alphastats/gui/pages/06_LLM.py

Lines changed: 18 additions & 12 deletions
```diff
@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Dict
 
 import pandas as pd
@@ -211,7 +212,7 @@ def llm_config():
     st.markdown("##### Prompts generated based on analysis input")
     with st.expander("System message", expanded=False):
         system_message = st.text_area(
-            "",
+            " ",
             value=get_system_message(st.session_state[StateKeys.DATASET]),
             height=150,
             disabled=llm_integration_set_for_model,
@@ -221,7 +222,7 @@ def llm_config():
     with st.expander("Initial prompt", expanded=True):
         feature_to_repr_map = st.session_state[StateKeys.DATASET]._feature_to_repr_map
         initial_prompt = st.text_area(
-            "",
+            " ",
             value=get_initial_prompt(
                 plot_parameters,
                 list(
@@ -304,16 +305,12 @@ def llm_chat(
     # Alternatively write it all in one pdf report using e.g. pdfrw and reportlab (I have code for that combo).
 
     # no. tokens spent
-    total_tokens = 0
-    pinned_tokens = 0
-    for message in llm_integration.get_print_view(show_all=show_all):
+    messages, total_tokens, pinned_tokens = llm_integration.get_print_view(
+        show_all=show_all
+    )
+    for message in messages:
         with st.chat_message(message[MessageKeys.ROLE]):
             st.markdown(message[MessageKeys.CONTENT])
-            tokens = llm_integration.estimate_tokens([message])
-            if message[MessageKeys.IN_CONTEXT]:
-                total_tokens += tokens
-            if message[MessageKeys.PINNED]:
-                pinned_tokens += tokens
             if (
                 message[MessageKeys.PINNED]
                 or not message[MessageKeys.IN_CONTEXT]
@@ -325,6 +322,7 @@ def llm_chat(
             if not message[MessageKeys.IN_CONTEXT]:
                 token_message += ":x: "
             if show_individual_tokens:
+                tokens = llm_integration.estimate_tokens([message])
                 token_message += f"*tokens: {str(tokens)}*"
             st.markdown(token_message)
         for artifact in message[MessageKeys.ARTIFACTS]:
@@ -342,15 +340,24 @@ def llm_chat(
         f"*total tokens used: {str(total_tokens)}, tokens used for pinned messages: {str(pinned_tokens)}*"
     )
 
+    if st.session_state[StateKeys.RECENT_CHAT_WARNINGS]:
+        st.warning("Warnings during last chat completion:")
+        for warning in st.session_state[StateKeys.RECENT_CHAT_WARNINGS]:
+            st.warning(str(warning.message).replace("\n", "\n\n"))
+
     if prompt := st.chat_input("Say something"):
         with st.chat_message(Roles.USER):
            st.markdown(prompt)
            if show_individual_tokens:
                st.markdown(
                    f"*tokens: {str(llm_integration.estimate_tokens([{MessageKeys.CONTENT:prompt}]))}*"
                )
-        with st.spinner("Processing prompt..."):
+        with st.spinner("Processing prompt..."), warnings.catch_warnings(
+            record=True
+        ) as caught_warnings:
             llm_integration.chat_completion(prompt)
+        st.session_state[StateKeys.RECENT_CHAT_WARNINGS] = caught_warnings
+
         st.rerun(scope="fragment")
 
     st.download_button(
@@ -378,7 +385,6 @@ def llm_chat(
         key="show_individual_tokens",
         help="Show individual token estimates for each message.",
     )
-
     llm_chat(
         st.session_state[StateKeys.LLM_INTEGRATION][model_name],
         show_all,
```

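The core mechanism here is `warnings.catch_warnings(record=True)`: warnings raised during `chat_completion` (e.g. from truncation) are collected into a list instead of going to stderr, stored in session state, and re-rendered via `st.warning` on the next rerun. A minimal sketch of that pattern outside of Streamlit; the `chat_completion` stub is hypothetical, and `simplefilter("always")` is an addition so that repeated warnings are not deduplicated away:

```python
import warnings


def chat_completion() -> str:
    # Hypothetical stand-in for llm_integration.chat_completion(prompt):
    # warns when the conversation is truncated to fit the token limit.
    warnings.warn("Conversation truncated: oldest messages dropped from context.")
    return "assistant reply"


# record=True collects warnings in a list instead of printing them,
# so the caller can surface them in the UI later.
with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")  # also record repeats of the same warning
    reply = chat_completion()

# Each entry is a warnings.WarningMessage; .message holds the Warning instance.
for w in caught_warnings:
    print(f"WARNING: {w.message}")
```

Stashing `caught_warnings` in session state (rather than rendering inline) matters because the fragment calls `st.rerun` right after the completion: anything rendered before the rerun would be wiped, while session state survives it.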
alphastats/gui/utils/ui_helper.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -145,6 +145,9 @@ def init_session_state() -> None:
     if StateKeys.MAX_TOKENS not in st.session_state:
         st.session_state[StateKeys.MAX_TOKENS] = 10000
 
+    if StateKeys.RECENT_CHAT_WARNINGS not in st.session_state:
+        st.session_state[StateKeys.RECENT_CHAT_WARNINGS] = []
+
 
 class StateKeys(metaclass=ConstantsClass):
     USER_SESSION_ID = "user_session_id"
@@ -164,6 +167,7 @@ class StateKeys(metaclass=ConstantsClass):
     SELECTED_GENES_DOWN = "selected_genes_down"
     SELECTED_UNIPROT_FIELDS = "selected_uniprot_fields"
     MAX_TOKENS = "max_tokens"
+    RECENT_CHAT_WARNINGS = "recent_chat_warnings"
 
     ORGANISM = "organism"  # TODO this is essentially a constant
```

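The `not in st.session_state` guard is the standard Streamlit init-once idiom: the script re-executes top to bottom on every interaction, so an unconditional assignment would reset the stored warnings on each rerun. A minimal standalone sketch of the idiom (run with `streamlit run app.py`; the key name is copied from the diff, the counter display is illustrative):

```python
import streamlit as st

RECENT_CHAT_WARNINGS = "recent_chat_warnings"

# Initialize exactly once per session: only assign when the key is absent,
# otherwise every rerun would wipe the warnings from the previous chat turn.
if RECENT_CHAT_WARNINGS not in st.session_state:
    st.session_state[RECENT_CHAT_WARNINGS] = []

st.write(f"{len(st.session_state[RECENT_CHAT_WARNINGS])} recent warnings stored")
```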
alphastats/llm/llm_integration.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -373,18 +373,27 @@ def _chat_completion_create(self) -> ChatCompletion:
         logger.info(".. done")
         return result
 
-    def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
+    def get_print_view(
+        self, show_all=False
+    ) -> Tuple[List[Dict[str, Any]], float, float]:
         """Get a structured view of the conversation history for display purposes."""
 
         print_view = []
+        total_tokens = 0
+        pinned_tokens = 0
         for message_idx, message in enumerate(self._all_messages):
+            tokens = self.estimate_tokens([message])
+            in_context = message in self._messages
+            if in_context:
+                total_tokens += tokens
+                if message[MessageKeys.PINNED]:
+                    pinned_tokens += tokens
             if not show_all and (
                 message[MessageKeys.ROLE] in [Roles.TOOL, Roles.SYSTEM]
             ):
                 continue
             if not show_all and MessageKeys.TOOL_CALLS in message:
                 continue
-            in_context = message in self._messages
 
             print_view.append(
                 {
@@ -395,11 +404,12 @@ def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
                     MessageKeys.PINNED: message[MessageKeys.PINNED],
                 }
             )
-        return print_view
+
+        return print_view, total_tokens, pinned_tokens
 
     def get_chat_log_txt(self) -> str:
         """Get a chat log in text format for saving. It excludes tool replies, as they are usually also represented in the artifacts."""
-        messages = self.get_print_view(show_all=True)
+        messages, _, _ = self.get_print_view(show_all=True)
         chatlog = ""
         for message in messages:
             if message[MessageKeys.ROLE] == Roles.TOOL:
```

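Moving the token accounting into `get_print_view` also changes the semantics: tokens are now counted over every in-context message before the display filter runs, so hidden tool/system messages contribute to the totals regardless of `show_all` (previously the GUI only counted what it rendered). A condensed sketch of that ordering, with simplified message dicts and a hypothetical `estimate_tokens` heuristic standing in for the real ones:

```python
from typing import Any, Dict, List, Tuple


def estimate_tokens(messages: List[Dict[str, Any]]) -> int:
    # Hypothetical heuristic: roughly 4 characters per token.
    return sum(len(m["content"]) // 4 for m in messages)


def get_print_view(
    all_messages: List[Dict[str, Any]],
    in_context: List[Dict[str, Any]],
    show_all: bool = False,
) -> Tuple[List[Dict[str, Any]], int, int]:
    """Count tokens over every in-context message, then filter for display."""
    print_view: List[Dict[str, Any]] = []
    total_tokens = pinned_tokens = 0
    for message in all_messages:
        tokens = estimate_tokens([message])
        if message in in_context:
            total_tokens += tokens
            if message["pinned"]:
                pinned_tokens += tokens
        # Display filtering happens *after* counting, so hiding tool/system
        # messages (show_all=False) no longer changes the reported totals.
        if not show_all and message["role"] in ("tool", "system"):
            continue
        print_view.append(message)
    return print_view, total_tokens, pinned_tokens


msgs = [
    {"role": "system", "content": "You are a helpful assistant.", "pinned": True},
    {"role": "user", "content": "What changed?", "pinned": False},
]
view, total, pinned = get_print_view(msgs, in_context=msgs)
assert len(view) == 1            # system message hidden from the view ...
assert total > 0 and pinned > 0  # ... but still counted toward both totals
```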
tests/llm/test_llm_integration.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -481,7 +481,7 @@ def test_handle_function_calls(
 
 def test_get_print_view_default(llm_with_conversation):
     """Test get_print_view with default settings (show_all=False)"""
-    print_view = llm_with_conversation.get_print_view()
+    print_view, _, _ = llm_with_conversation.get_print_view()
 
     # Should only include user and assistant messages without tool_calls
     assert print_view == [
@@ -518,7 +518,7 @@ def test_get_print_view_default(llm_with_conversation):
 
 def test_get_print_view_show_all(llm_with_conversation):
     """Test get_print_view with default settings (show_all=True)"""
-    print_view = llm_with_conversation.get_print_view(show_all=True)
+    print_view, _, _ = llm_with_conversation.get_print_view(show_all=True)
 
     # Should only include user and assistant messages without tool_calls
     assert print_view == [
```

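Since the return type changed, a companion test asserting the new counts would sit naturally next to the updated ones. A hypothetical sketch reusing the existing `llm_with_conversation` fixture (not part of this commit):

```python
def test_get_print_view_returns_token_counts(llm_with_conversation):
    """Pinned messages are a subset of the context, so their tokens are bounded by the total."""
    print_view, total_tokens, pinned_tokens = llm_with_conversation.get_print_view()

    assert total_tokens >= pinned_tokens >= 0
```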