Skip to content

Commit

Permalink
Fix OpenAI client reinitialization bug; add a `switch_model` method.
Browse files Browse the repository at this point in the history
  • Loading branch information
nilix-ba committed Mar 6, 2025
1 parent ff55510 commit 1d41aff
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 18 deletions.
86 changes: 69 additions & 17 deletions src/fundus_murag/assistant/openai_fundus_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fundus_murag.config.config import load_config
from fundus_murag.singleton_meta import SingletonMeta

# Default model name if none provided
OPENAI_MODEL = "gpt-4o-mini"


Expand All @@ -31,58 +32,113 @@ def __init__(self, model_name: str, use_tools: bool = True) -> None:
self._conf = load_config()
openai.api_key = self._conf.open_ai_project_id

# a unified history for both the base and OpenAI implementation.
self.chat_history = []
# ensures both references point to the same list.
self._persistent_history = self.chat_history

self._model_config: ModelConfig = self._create_model_config(
model_name, use_tools
)
self._has_active_session = False

def _get_full_history(self) -> List[dict]:
    """Build the message list sent to the API.

    Prepends a system-role entry when the model config carries a
    system instruction, then appends every recorded conversation turn.
    """
    system_instruction = self._model_config["system_instruction"]
    messages: List[dict] = []
    if system_instruction:
        messages = [{"role": "system", "content": system_instruction}]
    return messages + list(self._persistent_history)

def switch_model(self, new_model_name: str):
    """Point the assistant at a different model.

    No-op when *new_model_name* equals the current model.  Otherwise the
    model configuration is rebuilt and the chat session is reset, since
    the history is not carried over between models.
    """
    if new_model_name == self.model_name:
        return
    self.model_name = new_model_name
    self._model_config = self._create_model_config(
        new_model_name, self.use_tools
    )
    self.reset_chat_session()

def _send_text_message_to_model(self, prompt: str) -> Any:
    """Send a user text prompt to the OpenAI chat-completions API.

    The prompt and the assistant's reply are both appended to the
    persistent history so later turns keep the full context.

    Returns the raw API response object, or an ``{"error": ...}`` dict
    when the request fails.
    """
    logger.info(f"Sending text prompt to OpenAI: {prompt}")

    self._persistent_history.append({"role": "user", "content": prompt})

    try:
        response = openai.chat.completions.create(
            model=self._model_config["model_name"],
            # Full history includes the system instruction at position 0.
            messages=self._get_full_history(),  # type: ignore
            temperature=self._model_config["temperature"],
        )
        message_content = response.choices[0].message.content
        self._persistent_history.append(
            {"role": "assistant", "content": message_content}
        )
        logger.info(f"OpenAI - Text response: {message_content}")
        return response
    except openai.error.OpenAIError as e:  # type: ignore
        logger.error(f"OpenAI - OpenAIError during text prompt handling: {e}")
        return {"error": str(e)}
    except Exception as e:
        logger.error(f"OpenAI - Unexpected error during text prompt handling: {e}")
        return {"error": str(e)}

def _extract_text_from_response(self, raw_response: Any) -> str:
    """Pull the assistant's text out of a chat-completions response.

    Returns "" (and logs the error) when *raw_response* does not have
    the expected ``choices[0].message.content`` attribute shape.
    """
    try:
        return raw_response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error extracting text from response: {e}")
        return ""

def _start_new_chat_session(self) -> None:
    """
    Resets the conversation history and flags a new session.

    Wipes any prior history via reset_chat_session(), then marks the
    session active so _chat_session_active() reports True.
    """
    logger.info("Starting new OpenAI chat session.")
    self.reset_chat_session()
    # Set AFTER the reset, which itself forces the flag back to False.
    self._has_active_session = True

def reset_chat_session(self) -> None:
    """Clear all conversation state and mark the session inactive.

    Delegates to the base implementation first, then empties the
    persistent history in place — clearing (rather than rebinding) keeps
    any aliases of the history list valid.
    """
    super().reset_chat_session()
    self._persistent_history.clear()
    self._has_active_session = False

def _chat_session_active(self) -> bool:
    # True between _start_new_chat_session() and the next reset_chat_session().
    return self._has_active_session

def _send_followup_message_to_model(self, content: Any) -> Any:
    """Send function-call results back to the model for a follow-up turn.

    *content* (typically a tool/function result) is recorded in the
    persistent history as an assistant message, the full history is
    resent, and the model's reply is appended as well.

    Returns the raw API response object, or an ``{"error": ...}`` dict
    when the request fails.
    """
    logger.info(
        f"OpenAI - Sending follow-up content after function call: {content}"
    )

    self._persistent_history.append({"role": "assistant", "content": str(content)})

    try:
        response = openai.chat.completions.create(
            model=self._model_config["model_name"],
            # Full history includes the system instruction at position 0.
            messages=self._get_full_history(),  # type: ignore
            temperature=self._model_config["temperature"],
        )
        message_content = response.choices[0].message.content
        self._persistent_history.append(
            {"role": "assistant", "content": message_content}
        )
        logger.info(f"OpenAI - Followup response: {message_content}")
        return response
    except openai.error.OpenAIError as e:  # type: ignore
        logger.error(f"OpenAI - OpenAIError during followup handling: {e}")
        return {"error": str(e)}
    except Exception as e:
        logger.error(f"OpenAI - Unexpected error during followup handling: {e}")
        return {"error": str(e)}

def send_text_image_message(
Expand All @@ -92,10 +148,8 @@ def send_text_image_message(

def _create_model_config(self, model_name: str, use_tools: bool) -> ModelConfig:
    """Derive this instance's model config from the shared base config.

    Falls back to OPENAI_MODEL when *model_name* is empty or None.
    NOTE(review): *use_tools* is accepted but unused in this method —
    confirm whether tool wiring happens elsewhere.
    """
    config = BASE_MODEL_CONFIG.copy()
    config.update(
        model_name=model_name or OPENAI_MODEL,
        system_instruction=ASSISTANT_SYSTEM_INSTRUCTION,
    )
    return config

@staticmethod
Expand Down Expand Up @@ -127,11 +181,9 @@ def list_available_models(only_gpt: bool = False) -> pd.DataFrame:
data["output_token_limit"].append(None)

df = pd.DataFrame(data)
logger.info(
f"Found {len(df)} models matching initial filter (only_gpt={only_gpt})."
)
# logger.info(f"Found {len(df)} models matching initial filter (only_gpt={only_gpt}).")

# models that can process images
# Apply additional filters for photo-processing capable models.
photo_model_substrings = ["o1", "gpt-4"]
exclude_substrings = ["audio", "realtime", "preview", "turbo"]

Expand All @@ -143,5 +195,5 @@ def list_available_models(only_gpt: bool = False) -> pd.DataFrame:
& ~df["name"].str.contains(exclude_pattern, case=False)
]

logger.info(f"After photo-processing filter, {len(df_filtered)} models remain.")
# logger.info(f"After photo-processing filter, {len(df_filtered)} models remain.")
return df_filtered
6 changes: 5 additions & 1 deletion src/fundus_murag/ui/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,8 @@ def get_assistant_instance(model_name: str, merged_df: pd.DataFrame):
if model_source == "gemini":
return GeminiFundusAssistant(model_name)
else:
return OpenAIFundusAssistant(model_name)
assistant = OpenAIFundusAssistant(model_name)
# If the current instance's model does not match the requested model, switch it.
if assistant.model_name != model_name:
assistant.switch_model(model_name)
return assistant

0 comments on commit 1d41aff

Please sign in to comment.