Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 10, 2024
1 parent c38b679 commit d2e24f6
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,12 @@ def chat(
model=self.model_name, messages=fixed_chat, **tmp_kwargs # type: ignore
)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

def f() -> Iterator[Optional[str]]:
for chunk in response:
chunk_message = chunk.choices[0].delta.content # type: ignore
yield chunk_message

return f()
else:
return cast(str, response.choices[0].message.content)
Expand Down Expand Up @@ -192,10 +194,12 @@ def generate(
model=self.model_name, messages=message, **tmp_kwargs # type: ignore
)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

def f() -> Iterator[Optional[str]]:
for chunk in response:
chunk_message = chunk.choices[0].delta.content # type: ignore
yield chunk_message

return f()
else:
return cast(str, response.choices[0].message.content)
Expand Down Expand Up @@ -371,6 +375,7 @@ def chat(
data.update(tmp_kwargs)
json_data = json.dumps(data)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

def f() -> Iterator[Optional[str]]:
with requests.post(url, data=json_data, stream=True) as stream:
if stream.status_code != 200:
Expand All @@ -384,6 +389,7 @@ def f() -> Iterator[Optional[str]]:
yield None
else:
yield chunk_data["message"]["content"]

return f()
else:
stream = requests.post(url, data=json_data)
Expand Down Expand Up @@ -416,6 +422,7 @@ def generate(
data.update(tmp_kwargs)
json_data = json.dumps(data)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

def f() -> Iterator[Optional[str]]:
with requests.post(url, data=json_data, stream=True) as stream:
if stream.status_code != 200:
Expand All @@ -429,6 +436,7 @@ def f() -> Iterator[Optional[str]]:
yield None
else:
yield chunk_data["response"]

return f()
else:
stream = requests.post(url, data=json_data)
Expand Down Expand Up @@ -498,14 +506,19 @@ def chat(
model=self.model_name, messages=messages, **tmp_kwargs
)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

def f() -> Iterator[Optional[str]]:
for chunk in response:
if chunk.type == "message_start" or chunk.type == "content_block_start":
if (
chunk.type == "message_start"
or chunk.type == "content_block_start"
):
continue
elif chunk.type == "content_block_delta":
yield chunk.delta.text
elif chunk.type == "message_stop":
yield None

return f()
else:
return cast(str, response.content[0].text)
Expand Down Expand Up @@ -541,14 +554,19 @@ def generate(
**tmp_kwargs,
)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

def f() -> Iterator[Optional[str]]:
for chunk in response:
if chunk.type == "message_start" or chunk.type == "content_block_start":
if (
chunk.type == "message_start"
or chunk.type == "content_block_start"
):
continue
elif chunk.type == "content_block_delta":
yield chunk.delta.text
elif chunk.type == "message_stop":
yield None

return f()
else:
return cast(str, response.content[0].text)

0 comments on commit d2e24f6

Please sign in to comment.