diff --git a/alphastats/gui/pages/06_LLM.py b/alphastats/gui/pages/06_LLM.py
index 76301cd1..50f39fe3 100644
--- a/alphastats/gui/pages/06_LLM.py
+++ b/alphastats/gui/pages/06_LLM.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Dict
 
 import pandas as pd
@@ -211,7 +212,7 @@ def llm_config():
     st.markdown("##### Prompts generated based on analysis input")
     with st.expander("System message", expanded=False):
         system_message = st.text_area(
-            "",
+            " ",
             value=get_system_message(st.session_state[StateKeys.DATASET]),
             height=150,
             disabled=llm_integration_set_for_model,
@@ -221,7 +222,7 @@ def llm_config():
     with st.expander("Initial prompt", expanded=True):
         feature_to_repr_map = st.session_state[StateKeys.DATASET]._feature_to_repr_map
         initial_prompt = st.text_area(
-            "",
+            " ",
             value=get_initial_prompt(
                 plot_parameters,
                 list(
@@ -304,16 +305,12 @@ def llm_chat(
     # Alternatively write it all in one pdf report using e.g. pdfrw and reportlab (I have code for that combo).
 
     # no. tokens spent
-    total_tokens = 0
-    pinned_tokens = 0
-    for message in llm_integration.get_print_view(show_all=show_all):
+    messages, total_tokens, pinned_tokens = llm_integration.get_print_view(
+        show_all=show_all
+    )
+    for message in messages:
         with st.chat_message(message[MessageKeys.ROLE]):
             st.markdown(message[MessageKeys.CONTENT])
-            tokens = llm_integration.estimate_tokens([message])
-            if message[MessageKeys.IN_CONTEXT]:
-                total_tokens += tokens
-            if message[MessageKeys.PINNED]:
-                pinned_tokens += tokens
             if (
                 message[MessageKeys.PINNED]
                 or not message[MessageKeys.IN_CONTEXT]
@@ -325,6 +322,7 @@ def llm_chat(
                 if not message[MessageKeys.IN_CONTEXT]:
                     token_message += ":x: "
                 if show_individual_tokens:
+                    tokens = llm_integration.estimate_tokens([message])
                     token_message += f"*tokens: {str(tokens)}*"
                 st.markdown(token_message)
             for artifact in message[MessageKeys.ARTIFACTS]:
@@ -342,6 +340,11 @@ def llm_chat(
         f"*total tokens used: {str(total_tokens)}, tokens used for pinned messages: {str(pinned_tokens)}*"
     )
 
+    if st.session_state[StateKeys.RECENT_CHAT_WARNINGS]:
+        st.warning("Warnings during last chat completion:")
+        for warning in st.session_state[StateKeys.RECENT_CHAT_WARNINGS]:
+            st.warning(str(warning.message).replace("\n", "\n\n"))
+
     if prompt := st.chat_input("Say something"):
         with st.chat_message(Roles.USER):
             st.markdown(prompt)
@@ -349,8 +352,12 @@ def llm_chat(
                 st.markdown(
                     f"*tokens: {str(llm_integration.estimate_tokens([{MessageKeys.CONTENT:prompt}]))}*"
                 )
-        with st.spinner("Processing prompt..."):
+        with st.spinner("Processing prompt..."), warnings.catch_warnings(
+            record=True
+        ) as caught_warnings:
             llm_integration.chat_completion(prompt)
+        st.session_state[StateKeys.RECENT_CHAT_WARNINGS] = caught_warnings
+
         st.rerun(scope="fragment")
 
     st.download_button(
@@ -378,7 +385,6 @@ def llm_chat(
         key="show_individual_tokens",
         help="Show individual token estimates for each message.",
     )
-
     llm_chat(
         st.session_state[StateKeys.LLM_INTEGRATION][model_name],
         show_all,
diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py
index 52690831..28041297 100644
--- a/alphastats/gui/utils/ui_helper.py
+++ b/alphastats/gui/utils/ui_helper.py
@@ -145,6 +145,9 @@ def init_session_state() -> None:
     if StateKeys.MAX_TOKENS not in st.session_state:
         st.session_state[StateKeys.MAX_TOKENS] = 10000
 
+    if StateKeys.RECENT_CHAT_WARNINGS not in st.session_state:
+        st.session_state[StateKeys.RECENT_CHAT_WARNINGS] = []
+
 
 class StateKeys(metaclass=ConstantsClass):
     USER_SESSION_ID = "user_session_id"
@@ -164,6 +167,7 @@ class StateKeys(metaclass=ConstantsClass):
     SELECTED_GENES_DOWN = "selected_genes_down"
     SELECTED_UNIPROT_FIELDS = "selected_uniprot_fields"
     MAX_TOKENS = "max_tokens"
+    RECENT_CHAT_WARNINGS = "recent_chat_warnings"
     ORGANISM = "organism"  # TODO this is essentially a constant
 
 
diff --git a/alphastats/llm/llm_integration.py b/alphastats/llm/llm_integration.py
index 89a839ea..2e3ebc82 100644
--- a/alphastats/llm/llm_integration.py
+++ b/alphastats/llm/llm_integration.py
@@ -373,18 +373,27 @@ def _chat_completion_create(self) -> ChatCompletion:
         logger.info(".. done")
         return result
 
-    def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
+    def get_print_view(
+        self, show_all=False
+    ) -> Tuple[List[Dict[str, Any]], float, float]:
         """Get a structured view of the conversation history for display purposes."""
 
         print_view = []
+        total_tokens = 0
+        pinned_tokens = 0
         for message_idx, message in enumerate(self._all_messages):
+            tokens = self.estimate_tokens([message])
+            in_context = message in self._messages
+            if in_context:
+                total_tokens += tokens
+            if message[MessageKeys.PINNED]:
+                pinned_tokens += tokens
             if not show_all and (
                 message[MessageKeys.ROLE] in [Roles.TOOL, Roles.SYSTEM]
            ):
                 continue
             if not show_all and MessageKeys.TOOL_CALLS in message:
                 continue
-            in_context = message in self._messages
 
             print_view.append(
                 {
@@ -395,11 +404,12 @@ def get_print_view(self, show_all=False) -> List[Dict[str, Any]]:
                     MessageKeys.PINNED: message[MessageKeys.PINNED],
                 }
             )
-        return print_view
+
+        return print_view, total_tokens, pinned_tokens
 
     def get_chat_log_txt(self) -> str:
         """Get a chat log in text format for saving. It excludes tool replies, as they are usually also represented in the artifacts."""
-        messages = self.get_print_view(show_all=True)
+        messages, _, _ = self.get_print_view(show_all=True)
         chatlog = ""
         for message in messages:
             if message[MessageKeys.ROLE] == Roles.TOOL:
diff --git a/tests/llm/test_llm_integration.py b/tests/llm/test_llm_integration.py
index 8f6ed9aa..c9745bfe 100644
--- a/tests/llm/test_llm_integration.py
+++ b/tests/llm/test_llm_integration.py
@@ -481,7 +481,7 @@ def test_handle_function_calls(
 def test_get_print_view_default(llm_with_conversation):
     """Test get_print_view with default settings (show_all=False)"""
 
-    print_view = llm_with_conversation.get_print_view()
+    print_view, _, _ = llm_with_conversation.get_print_view()
 
     # Should only include user and assistant messages without tool_calls
     assert print_view == [
@@ -518,7 +518,7 @@ def test_get_print_view_default(llm_with_conversation):
 def test_get_print_view_show_all(llm_with_conversation):
     """Test get_print_view with default settings (show_all=True)"""
 
-    print_view = llm_with_conversation.get_print_view(show_all=True)
+    print_view, _, _ = llm_with_conversation.get_print_view(show_all=True)
 
     # Should only include user and assistant messages without tool_calls
     assert print_view == [