diff --git a/main.py b/main.py
index 33dba54..7c20a16 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,6 @@
 import os
 import argparse
+import asyncio
 import gradio as gr
 from difflib import Differ
 from string import Template
@@ -25,7 +26,7 @@ def find_attached_file(filename, attached_files):
         return file
     return None
 
-def echo(message, history, state):
+async def echo(message, history, state):
     attached_file = None
 
     if message['files']:
@@ -34,7 +35,7 @@ def echo(message, history, state):
         attached_file = find_attached_file(filename, state["attached_files"])
 
         if attached_file is None:
-            path_gcp = client.files.upload(path=path_local)
+            path_gcp = await client.files.upload(path=path_local)
             state["attached_files"].append({
                 "name": filename,
                 "path_local": path_local,
@@ -52,35 +53,42 @@ def echo(message, history, state):
     chat_history = chat_history + user_message
     state['messages'] = chat_history
 
-    response = client.models.generate_content(
-        model="gemini-1.5-flash",
-        contents=state['messages'],
-    )
-    model_response = response.text
-
+    response_chunks = ""
+    async for chunk in await client.models.generate_content_stream(
+        model="gemini-2.0-flash", contents=state['messages'],
+    ):
+        response_chunks += chunk.text
+        # if the model generates too fast, Gradio cannot render the stream in real time.
+        await asyncio.sleep(0.1)
+        yield (
+            response_chunks,
+            state,
+            state['summary_diff_history'][-1] if len(state['summary_diff_history']) > 1 else "",
+            state['summary_history'][-1] if len(state['summary_history']) > 1 else "",
+            gr.Slider(
+                visible=False if len(state['summary_history']) <= 1 else True,
+                interactive=False if len(state['summary_history']) <= 1 else True,
+            ),
+        )
+
     # make summary
-    if state['summary'] != "":
-        response = client.models.generate_content(
-            model="gemini-1.5-flash",
-            contents=[
-                Template(
-                    prompt_tmpl['summarization']['prompt']
-                ).safe_substitute(
-                    previous_summary=state['summary'],
-                    latest_conversation=str({"user": message['text'], "assistant": model_response})
-                )
-            ],
-            config={'response_mime_type': 'application/json',
-                    'response_schema': SummaryResponses,
-            },
-        )
+    response = await client.models.generate_content(
+        model="gemini-2.0-flash",
+        contents=[
+            Template(
+                prompt_tmpl['summarization']['prompt']
+            ).safe_substitute(
+                previous_summary=state['summary'],
+                latest_conversation=str({"user": message['text'], "assistant": response_chunks})
+            )
+        ],
+        config={'response_mime_type': 'application/json',
+                'response_schema': SummaryResponses,
+        },
+    )
 
-    if state['summary'] != "":
-        prev_summary = state['summary_history'][-1]
-    else:
-        prev_summary = ""
+    prev_summary = state['summary_history'][-1] if len(state['summary_history']) >= 1 else ""
 
-    d = Differ()
     state['summary'] = (
         response.parsed.summary
         if getattr(response.parsed, "summary", None) is not None
@@ -94,14 +102,13 @@ def echo(message, history, state):
     state['summary_diff_history'].append(
         [
             (token[2:], token[0] if token[0] != " " else None)
-            for token in d.compare(prev_summary, state['summary'])
+            for token in Differ().compare(prev_summary, state['summary'])
         ]
     )
 
-    return (
-        model_response,
+    yield (
+        response_chunks,
         state,
-        # state['summary'],
         state['summary_diff_history'][-1],
         state['summary_history'][-1],
         gr.Slider(
@@ -166,7 +173,7 @@ def main(args):
             # value="No summary yet. As you chat with the assistant, the summary will be updated automatically.",
             combine_adjacent=True,
             show_legend=True,
-            color_map={"+": "red", "-": "green"},
+            color_map={"-": "red", "+": "green"},
             elem_classes=["summary-window"],
             visible=False
         )
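The key change above is that `echo` becomes an async generator: it `yield`s a partial tuple for every streamed chunk, then one final tuple once the summary is ready. Below is a minimal standalone sketch of that stream-and-accumulate pattern. It assumes the public google-genai client, where `.aio` exposes the async API surface (the diff instead builds an `AsyncClient` directly); the prompt string and `YOUR_API_KEY` are placeholders.

```python
# Sketch only: stream a response and accumulate chunks, as echo() does above.
import asyncio
from google import genai

async def stream_demo() -> str:
    # Assumption: .aio is the async surface of the public google-genai client.
    client = genai.Client(api_key="YOUR_API_KEY").aio
    chunks = ""
    async for chunk in await client.models.generate_content_stream(
        model="gemini-2.0-flash", contents="Say hello in five languages.",
    ):
        chunks += chunk.text  # accumulate, like response_chunks in echo()
        print(chunk.text, end="", flush=True)
    return chunks

asyncio.run(stream_demo())
```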
diff --git a/utils.py b/utils.py
index d4b845a..ac823d7 100644
--- a/utils.py
+++ b/utils.py
@@ -9,14 +9,18 @@ def load_prompt(args):
 
 def setup_gemini_client(args):
     if args.vertexai:
-        client = genai.Client(
-            vertexai=args.vertexai,
-            project=args.vertexai_project,
-            location=args.vertexai_location
+        client = genai.client.AsyncClient(
+            genai.client.ApiClient(
+                vertexai=args.vertexai,
+                project=args.vertexai_project,
+                location=args.vertexai_location
+            )
         )
     else:
-        client = genai.Client(
-            api_key=args.ai_studio_api_key,
+        client = genai.client.AsyncClient(
+            genai.client.ApiClient(
+                api_key=args.ai_studio_api_key
+            )
         )
 
     return client
\ No newline at end of file
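Since `setup_gemini_client` now returns an async client, every call site must be awaited, as `echo` does in main.py. A quick smoke test of the returned client might look like the following; the `argparse.Namespace` construction is a hypothetical stand-in for the app's real argument parsing, with field names mirroring the flags used in the diff.

```python
# Hypothetical harness: verify the async client from setup_gemini_client works.
import argparse
import asyncio

from utils import setup_gemini_client

async def main():
    # Stand-in for the app's parsed CLI arguments; field names mirror the diff.
    args = argparse.Namespace(
        vertexai=False,
        ai_studio_api_key="YOUR_API_KEY",
        vertexai_project=None,
        vertexai_location=None,
    )
    client = setup_gemini_client(args)
    # generate_content is a coroutine on the async client, so await it.
    response = await client.models.generate_content(
        model="gemini-2.0-flash", contents="ping",
    )
    print(response.text)

asyncio.run(main())
```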