From d873ba6fa565d88f67a90294c4d1416b6acf5fd9 Mon Sep 17 00:00:00 2001
From: chansung
Date: Mon, 10 Feb 2025 11:14:41 +0900
Subject: [PATCH 1/3] async & streaming support

---
 main.py | 72 +++++++++++++++++++++++++++++++--------------------------
 1 file changed, 39 insertions(+), 33 deletions(-)

diff --git a/main.py b/main.py
index 33dba54..8a09949 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,6 @@
 import os
 import argparse
+import asyncio
 import gradio as gr
 from difflib import Differ
 from string import Template
@@ -25,7 +26,7 @@ def find_attached_file(filename, attached_files):
             return file
     return None
 
-def echo(message, history, state):
+async def echo(message, history, state):
     attached_file = None
 
     if message['files']:
@@ -34,7 +35,7 @@ def echo(message, history, state):
         attached_file = find_attached_file(filename, state["attached_files"])
 
         if attached_file is None:
-            path_gcp = client.files.upload(path=path_local)
+            path_gcp = await client.files.AsyncFiles.upload(path=path_local)
             state["attached_files"].append({
                 "name": filename,
                 "path_local": path_local,
@@ -52,35 +53,41 @@ def echo(message, history, state):
     chat_history = chat_history + user_message
     state['messages'] = chat_history
 
-    response = client.models.generate_content(
-        model="gemini-1.5-flash",
-        contents=state['messages'],
-    )
-    model_response = response.text
-
+    response_chunks = ""
+    async for chunk in await client.aio.models.generate_content_stream(
+        model="gemini-2.0-flash", contents=state['messages'],
+    ):
+        response_chunks += chunk.text
+        await asyncio.sleep(0.1)
+        yield (
+            response_chunks,
+            state,
+            state['summary_diff_history'][-1] if len(state['summary_diff_history']) > 1 else "",
+            state['summary_history'][-1] if len(state['summary_history']) > 1 else "",
+            gr.Slider(
+                visible=False if len(state['summary_history']) <= 1 else True,
+                interactive=False if len(state['summary_history']) <= 1 else True,
+            ),
+        )
+
     # make summary
-    if state['summary'] != "":
-        response = client.models.generate_content(
-            model="gemini-1.5-flash",
-            contents=[
-                Template(
-                    prompt_tmpl['summarization']['prompt']
-                ).safe_substitute(
-                    previous_summary=state['summary'],
-                    latest_conversation=str({"user": message['text'], "assistant": model_response})
-                )
-            ],
-            config={'response_mime_type': 'application/json',
-                    'response_schema': SummaryResponses,
-            },
-        )
+    response = await client.aio.models.generate_content(
+        model="gemini-2.0-flash",
+        contents=[
+            Template(
+                prompt_tmpl['summarization']['prompt']
+            ).safe_substitute(
+                previous_summary=state['summary'],
+                latest_conversation=str({"user": message['text'], "assistant": response_chunks})
+            )
+        ],
+        config={'response_mime_type': 'application/json',
+                'response_schema': SummaryResponses,
+        },
+    )
 
-    if state['summary'] != "":
-        prev_summary = state['summary_history'][-1]
-    else:
-        prev_summary = ""
+    prev_summary = state['summary_history'][-1] if len(state['summary_history']) >= 1 else ""
 
-    d = Differ()
     state['summary'] = (
         response.parsed.summary
         if getattr(response.parsed, "summary", None) is not None
@@ -94,14 +101,13 @@ def echo(message, history, state):
     state['summary_diff_history'].append(
         [
             (token[2:], token[0] if token[0] != " " else None)
-            for token in d.compare(prev_summary, state['summary'])
+            for token in Differ().compare(prev_summary, state['summary'])
         ]
     )
 
-    return (
-        model_response,
+    yield (
+        response_chunks,
         state,
-        # state['summary'],
         state['summary_diff_history'][-1],
         state['summary_history'][-1],
         gr.Slider(
@@ -166,7 +172,7 @@ def main(args):
                 # value="No summary yet. As you chat with the assistant, the summary will be updated automatically.",
                 combine_adjacent=True,
                 show_legend=True,
-                color_map={"+": "red", "-": "green"},
+                color_map={"-": "red", "+": "green"},
                 elem_classes=["summary-window"],
                 visible=False
             )

From 56ff9fe05612c156cade5dca9c3501ce186e965e Mon Sep 17 00:00:00 2001
From: chansung
Date: Mon, 10 Feb 2025 11:15:34 +0900
Subject: [PATCH 2/3] async & streaming support

---
 main.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/main.py b/main.py
index 8a09949..1b65753 100644
--- a/main.py
+++ b/main.py
@@ -58,6 +58,7 @@ async def echo(message, history, state):
         model="gemini-2.0-flash", contents=state['messages'],
     ):
         response_chunks += chunk.text
+        # when the model generates too fast, Gradio cannot keep up with rendering in real time.
         await asyncio.sleep(0.1)
         yield (
             response_chunks,

From ffa237d12f9c66ffc65f7f36e675bdda13f7cc13 Mon Sep 17 00:00:00 2001
From: chansung
Date: Mon, 10 Feb 2025 12:12:27 +0900
Subject: [PATCH 3/3] fix: replace genai.Client with genai.client.AsyncClient

---
 main.py  |  6 +++---
 utils.py | 16 ++++++++++------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/main.py b/main.py
index 1b65753..7c20a16 100644
--- a/main.py
+++ b/main.py
@@ -35,7 +35,7 @@ async def echo(message, history, state):
         attached_file = find_attached_file(filename, state["attached_files"])
 
         if attached_file is None:
-            path_gcp = await client.files.AsyncFiles.upload(path=path_local)
+            path_gcp = await client.files.upload(path=path_local)
             state["attached_files"].append({
                 "name": filename,
                 "path_local": path_local,
@@ -54,7 +54,7 @@ async def echo(message, history, state):
     state['messages'] = chat_history
 
     response_chunks = ""
-    async for chunk in await client.aio.models.generate_content_stream(
+    async for chunk in await client.models.generate_content_stream(
         model="gemini-2.0-flash", contents=state['messages'],
     ):
         response_chunks += chunk.text
@@ -72,7 +72,7 @@ async def echo(message, history, state):
         )
 
     # make summary
-    response = await client.aio.models.generate_content(
+    response = await client.models.generate_content(
         model="gemini-2.0-flash",
         contents=[
             Template(
diff --git a/utils.py b/utils.py
index d4b845a..ac823d7 100644
--- a/utils.py
+++ b/utils.py
@@ -9,14 +9,18 @@ def load_prompt(args):
 
 def setup_gemini_client(args):
     if args.vertexai:
-        client = genai.Client(
-            vertexai=args.vertexai,
-            project=args.vertexai_project,
-            location=args.vertexai_location
+        client = genai.client.AsyncClient(
+            genai.client.ApiClient(
+                vertexai=args.vertexai,
+                project=args.vertexai_project,
+                location=args.vertexai_location
+            )
         )
     else:
-        client = genai.Client(
-            api_key=args.ai_studio_api_key,
+        client = genai.client.AsyncClient(
+            genai.client.ApiClient(
+                api_key=args.ai_studio_api_key
+            )
         )
 
     return client
\ No newline at end of file
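Note on the streaming pattern (illustrative sketch, not part of the patch series): the changes above rely on Gradio accepting an async generator as the chat handler and on the google-genai SDK's client.aio.models.generate_content_stream returning an awaitable async iterator, as used in PATCH 1. The minimal standalone sketch below shows that pattern in isolation; the model name, the GOOGLE_API_KEY environment variable, and the stream_reply function name are placeholders, and the short sleep mirrors the PATCH 2 workaround that gives Gradio time to repaint between fast chunks.

import asyncio
import os

import gradio as gr
from google import genai

# synchronous client; its .aio attribute exposes the async surface used in PATCH 1
client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

async def stream_reply(message, history):
    chunks = ""
    # awaiting generate_content_stream returns an async iterator of partial responses
    async for chunk in await client.aio.models.generate_content_stream(
        model="gemini-2.0-flash", contents=message,
    ):
        chunks += chunk.text
        # brief pause so the UI can repaint when chunks arrive very quickly
        await asyncio.sleep(0.1)
        yield chunks  # each yield replaces the partial assistant message

if __name__ == "__main__":
    # ChatInterface streams each yielded value as the in-progress reply
    gr.ChatInterface(stream_reply).launch()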