From d2e24f621d6d857cde221ff452ca467086e7b8f7 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 9 Aug 2024 20:39:14 -0700 Subject: [PATCH] black formatting --- vision_agent/lmm/lmm.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index f9470b71..9a8c5bf1 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -151,10 +151,12 @@ def chat( model=self.model_name, messages=fixed_chat, **tmp_kwargs # type: ignore ) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + def f() -> Iterator[Optional[str]]: for chunk in response: chunk_message = chunk.choices[0].delta.content # type: ignore yield chunk_message + return f() else: return cast(str, response.choices[0].message.content) @@ -192,10 +194,12 @@ def generate( model=self.model_name, messages=message, **tmp_kwargs # type: ignore ) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + def f() -> Iterator[Optional[str]]: for chunk in response: chunk_message = chunk.choices[0].delta.content # type: ignore yield chunk_message + return f() else: return cast(str, response.choices[0].message.content) @@ -371,6 +375,7 @@ def chat( data.update(tmp_kwargs) json_data = json.dumps(data) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + def f() -> Iterator[Optional[str]]: with requests.post(url, data=json_data, stream=True) as stream: if stream.status_code != 200: @@ -384,6 +389,7 @@ def f() -> Iterator[Optional[str]]: yield None else: yield chunk_data["message"]["content"] + return f() else: stream = requests.post(url, data=json_data) @@ -416,6 +422,7 @@ def generate( data.update(tmp_kwargs) json_data = json.dumps(data) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + def f() -> Iterator[Optional[str]]: with requests.post(url, data=json_data, stream=True) as stream: if stream.status_code != 200: @@ -429,6 +436,7 @@ def f() -> Iterator[Optional[str]]: yield None else: yield chunk_data["response"] + return f() else: stream = requests.post(url, data=json_data) @@ -498,14 +506,19 @@ def chat( model=self.model_name, messages=messages, **tmp_kwargs ) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + def f() -> Iterator[Optional[str]]: for chunk in response: - if chunk.type == "message_start" or chunk.type == "content_block_start": + if ( + chunk.type == "message_start" + or chunk.type == "content_block_start" + ): continue elif chunk.type == "content_block_delta": yield chunk.delta.text elif chunk.type == "message_stop": yield None + return f() else: return cast(str, response.content[0].text) @@ -541,14 +554,19 @@ def generate( **tmp_kwargs, ) if "stream" in tmp_kwargs and tmp_kwargs["stream"]: + def f() -> Iterator[Optional[str]]: for chunk in response: - if chunk.type == "message_start" or chunk.type == "content_block_start": + if ( + chunk.type == "message_start" + or chunk.type == "content_block_start" + ): continue elif chunk.type == "content_block_delta": yield chunk.delta.text elif chunk.type == "message_stop": yield None + return f() else: return cast(str, response.content[0].text)