Commit 35e84de

fixed the function calls bug
nilix-ba committed Mar 6, 2025
1 parent 1d41aff commit 35e84de
Showing 1 changed file with 82 additions and 8 deletions.
src/fundus_murag/assistant/gemini_fundus_assistant.py (82 additions, 8 deletions)
@@ -10,6 +10,7 @@
     GenerationConfig,
     GenerationResponse,
     GenerativeModel,
+    Part,
 )
 
 from fundus_murag.assistant.base_fundus_assistant import BaseFundusAssistant
@@ -49,7 +50,8 @@ def __init__(self, model_name: str, use_tools: bool = True):
     def _send_text_message_to_model(self, prompt: str) -> GenerationResponse:
         logger.info(f"Prompt: {prompt}")
         response = self._chat_session.send_message(prompt) # type: ignore
-        logger.info(f"Text response received: {response}")
+        logger.info("Text response received from Gemini.")
+        self._print_text_response(response)
         return response
 
     def _extract_text_from_response(self, raw_response: GenerationResponse) -> str:
@@ -68,17 +70,90 @@ def reset_chat_session(self) -> None:
     def _chat_session_active(self) -> bool:
         return self._chat_session is not None
 
-    def _send_followup_message_to_model(self, content: Any) -> Any:
-        logger.info(f"Function call executed. Sending result back: {content}")
-        return self._chat_session.send_message(content) # type: ignore
-
+    def _is_text_response(self, response: GenerationResponse) -> bool:
+        try:
+            return response.candidates[0].text is not None
+        except Exception:
+            return False
+
+    def _is_function_call_response(self, response: GenerationResponse) -> bool:
+        try:
+            return len(response.candidates[0].function_calls) > 0
+        except Exception:
+            return False
+
+    def _execute_function_call(self, response: GenerationResponse) -> Part:
+        try:
+            self._print_function_call(response)
+            function_call = response.candidates[0].content.parts[0].function_call
+            params = dict(function_call.args)
+            res = self._function_call_handler.execute_function(
+                name=function_call.name,
+                convert_results_to_json=True,
+                **params,
+            )
+            part = Part.from_function_response(
+                name=function_call.name,
+                response={"content": res},
+            )
+            self._print_function_call_result(part)
+            return part
+        except Exception as e:
+            logger.error(f"Error executing function call: {e}")
+            return Part.from_function_response(
+                name="Error",
+                response={"content": str(e)},
+            )
+
+    def _send_followup_message_to_model(self, content: Any) -> GenerationResponse:
+        logger.info("Sending function call result back to Gemini.")
+        response = self._chat_session.send_message(content) # type: ignore
+        self._print_text_response(response)
+        return response
+
+    def _handle_function_calls(
+        self, response: GenerationResponse
+    ) -> GenerationResponse:
+        while self._is_function_call_response(response):
+            part = self._execute_function_call(response)
+            response = self._send_followup_message_to_model(part)
+        return response
+
+    def _print_function_call(self, response: GenerationResponse) -> None:
+        logger.info("*** Function Call Detected ***")
+        function_call = response.candidates[0].content.parts[0].function_call
+        logger.info(f"Function Name: {function_call.name}")
+        truncated_args = {}
+        for key, val in function_call.args.items():
+            if isinstance(val, str) and len(val) > 256:
+                val = val[:256] + "..."
+            truncated_args[key] = val
+        logger.info(f"Function Args: {truncated_args}")
+        logger.info("*" * 120)
+
+    def _print_function_call_result(self, result: Part) -> None:
+        logger.info("*** Function Call Response ***")
+        logger.info(result.to_dict())
+        logger.info("*" * 120)
+
+    def _print_text_response(self, response: GenerationResponse) -> None:
+        logger.info("+++ Text Response +++")
+        text = ""
+        try:
+            text = response.candidates[0].text or ""
+        except ValueError:
+            # Gemini raises ValueError if .text is not available
+            logger.info("No text available in this response.")
+        logger.info(text)
+        logger.info("+" * 120)
+
     @logger.catch
     def send_text_image_message(
         self, text_prompt: str, base64_images: list[str], reset_chat: bool = False
     ) -> GenerationResponse:
         raise NotImplementedError
 
-    def _load_model(self, model_name: str, use_tools: bool) -> GenerativeModel:
+    @staticmethod
+    def _load_model(model_name: str, use_tools: bool) -> GenerativeModel:
         model_name = model_name.lower()
         if "/" in model_name:
             model_name = model_name.split("/")[-1]
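
Taken together, the added helpers implement the standard Gemini function-calling round-trip: detect a pending call, run the matching local function, wrap its result in Part.from_function_response, and send it back until the model answers in text. A minimal caller-side sketch of that loop, assuming an already-started vertexai chat session; the prompt and the handlers registry are hypothetical, not part of this repository:

from vertexai.generative_models import GenerativeModel, Part

# Tool declarations omitted for brevity; in practice the model needs
# tools=[...] before it will ever emit function calls.
model = GenerativeModel("gemini-1.5-pro")
chat_session = model.start_chat()

response = chat_session.send_message("Find record 42")
while response.candidates[0].function_calls:
    # The model requested a tool call instead of answering in text.
    call = response.candidates[0].content.parts[0].function_call
    result = handlers[call.name](**dict(call.args))  # hypothetical registry
    # Hand the result back so the model can continue reasoning.
    response = chat_session.send_message(
        Part.from_function_response(name=call.name, response={"content": result})
    )
print(response.candidates[0].text)  # final text answer

The while loop mirrors _handle_function_calls above: a model may chain several tool calls before producing text, so a single follow-up message is not always enough.
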
@@ -99,6 +174,7 @@ def list_available_models(only_flash: bool = False) -> pd.DataFrame:
         genai.configure(credentials=creds)
         models = []
         for m in genai.list_models():
+            # Filter out any non-Gemini or dev/tuning versions
             if (
                 "gemini" not in m.name
                 or "1.5" not in m.name
@@ -118,8 +194,6 @@ def list_available_models(only_flash: bool = False) -> pd.DataFrame:
             "output_token_limit": [],
         }
         for model in models:
-            # N = model.name
-            # data["name"].append(N[7:])
             data["name"].append(model.name)
             data["display_name"].append(model.display_name)
             data["input_token_limit"].append(model.input_token_limit)
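One detail the new guards rely on, as the inline comment in _print_text_response notes: in the Vertex AI SDK, Candidate.text raises a ValueError when the candidate holds non-text parts (such as a pending function call) rather than returning None. A small sketch of a guard in the same spirit; the response_kind helper is hypothetical:

def response_kind(response) -> str:
    # Classify a GenerationResponse without tripping over the
    # ValueError that Candidate.text raises for non-text parts.
    candidate = response.candidates[0]
    if candidate.function_calls:
        return "function_call"
    try:
        return "text" if candidate.text else "empty"
    except ValueError:
        return "other"

This is why _is_text_response probes .text inside try/except instead of testing for None directly.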
