Add Streaming for LMMs (#191)
* added streaming

* fixed type errors

* fixed linting

* fixed generator func type

* black formatting

* fixed tests for streaming
dillonalaird authored Aug 13, 2024
1 parent 62b6137 commit 14f47e0
Showing 5 changed files with 242 additions and 63 deletions.
16 changes: 13 additions & 3 deletions tests/unit/fixtures.py
@@ -13,11 +13,21 @@ def langsmith_wrap_oepnai_mock(request, openai_llm_mock):
@pytest.fixture
def openai_lmm_mock(request):
content = request.param

def mock_generate(*args, **kwargs):
if kwargs.get("stream", False):

def generator():
for chunk in content.split(" ") + [None]:
yield MagicMock(choices=[MagicMock(delta=MagicMock(content=chunk))])

return generator()
else:
return MagicMock(choices=[MagicMock(message=MagicMock(content=content))])

# Note the path here is adjusted to where OpenAI is used, not where it's defined
with patch("vision_agent.lmm.lmm.OpenAI") as mock:
# Setup a mock response structure that matches what your code expects
mock_instance = mock.return_value
mock_instance.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content=content))]
)
mock_instance.chat.completions.create.return_value = mock_generate()
yield mock_instance
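
For context, the fixture above only fakes the OpenAI client; the streaming branch it exercises lives in vision_agent/lmm/lmm.py, whose diff is not expanded in this view. Below is a minimal sketch of what that branch might look like, inferred from the mock and the tests that follow; the class name, the default model, and the message-building step are assumptions, not taken from this diff.

from typing import Any, Iterator, List, Optional, Union

from openai import OpenAI


class OpenAILMMSketch:
    """Illustrative only; mirrors the interface the fixture above fakes."""

    def __init__(self, model_name: str = "gpt-4o") -> None:  # model name assumed
        self.client = OpenAI()
        self.model_name = model_name

    def generate(
        self,
        prompt: str,
        media: Optional[List[str]] = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[str, Iterator[Optional[str]]]:
        # The real class also attaches base64 image_url entries built from
        # `media`; that step is omitted here for brevity.
        messages: List[Any] = [{"role": "user", "content": prompt}]
        response = self.client.chat.completions.create(
            model=self.model_name, messages=messages, stream=stream, **kwargs
        )
        if stream:
            # With stream=True the SDK yields ChatCompletionChunk objects; the
            # final chunk's delta.content is None, matching the mock's trailing [None].
            def f() -> Iterator[Optional[str]]:
                for chunk in response:
                    yield chunk.choices[0].delta.content

            return f()
        return response.choices[0].message.content  # type: ignore[return-value]
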
62 changes: 62 additions & 0 deletions tests/unit/test_lmm.py
@@ -34,6 +34,24 @@ def test_generate_with_mock(openai_lmm_mock): # noqa: F811
)


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
def test_generate_with_mock_stream(openai_lmm_mock): # noqa: F811
temp_image = create_temp_image()
lmm = OpenAILMM()
response = lmm.generate("test prompt", media=[temp_image], stream=True)
expected_response = ["mocked", "response", None]
for i, chunk in enumerate(response):
assert chunk == expected_response[i]
assert (
"image_url"
in openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][1]
)


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
@@ -49,6 +67,23 @@ def test_chat_with_mock(openai_lmm_mock): # noqa: F811
)


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
def test_chat_with_mock_stream(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
response = lmm.chat([{"role": "user", "content": "test prompt"}], stream=True)
expected_response = ["mocked", "response", None]
for i, chunk in enumerate(response):
assert chunk == expected_response[i]
assert (
openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][0]["text"]
== "test prompt"
)


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
@@ -73,6 +108,33 @@ def test_call_with_mock(openai_lmm_mock): # noqa: F811
)


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
def test_call_with_mock_stream(openai_lmm_mock): # noqa: F811
expected_response = ["mocked", "response", None]
lmm = OpenAILMM()
response = lmm("test prompt", stream=True)
for i, chunk in enumerate(response):
assert chunk == expected_response[i]
assert (
openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][0]["text"]
== "test prompt"
)

response = lmm([{"role": "user", "content": "test prompt"}], stream=True)
for i, chunk in enumerate(response):
assert chunk == expected_response[i]
assert (
openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][0]["text"]
== "test prompt"
)


@pytest.mark.parametrize(
"openai_lmm_mock",
['{"Parameters": {"prompt": "cat"}}'],
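
All three streaming tests above expect word-by-word chunks followed by a trailing None, so callers are expected to filter out that final sentinel. A hypothetical consumer is sketched below; the import path and prompt are illustrative only, and a real OPENAI_API_KEY would be required.

from vision_agent.lmm import OpenAILMM  # assumed public import path

lmm = OpenAILMM()
for chunk in lmm("Describe the scene", stream=True):
    if chunk is not None:  # the last streamed chunk carries no content
        print(chunk, end="", flush=True)
print()
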
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
@@ -63,7 +63,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
dir=WORKSPACE,
conversation=conversation,
)
return extract_json(orch([{"role": "user", "content": prompt}]))
return extract_json(orch([{"role": "user", "content": prompt}], stream=False)) # type: ignore


def run_code_action(code: str, code_interpreter: CodeInterpreter) -> str:
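
The added stream=False argument and the # type: ignore comment suggest that the model's __call__ now returns either a complete string or a chunk iterator, a union mypy cannot narrow at this call site even though extract_json needs a str. A sketch of the assumed call interface follows; the parameter names and the Message alias are placeholders, not taken from this diff.

from typing import Any, Dict, Iterator, List, Optional, Protocol, Union

Message = Dict[str, Any]  # stand-in for vision_agent's Message type


class LMMCallable(Protocol):
    def __call__(
        self,
        input: Union[str, List[Message]],
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[str, Iterator[Optional[str]]]:
        """With stream=False the result is a complete str; with stream=True it
        is an iterator of chunks terminated by None."""
        ...
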
17 changes: 9 additions & 8 deletions vision_agent/agent/vision_agent_coder.py
@@ -129,7 +129,7 @@ def write_plans(
context = USER_REQ.format(user_request=user_request)
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
chat[-1]["content"] = prompt
return extract_json(model.chat(chat))
return extract_json(model(chat, stream=False)) # type: ignore


def pick_plan(
@@ -160,7 +160,7 @@ def pick_plan(
docstring=tool_info, plans=plan_str, previous_attempts="", media=media
)

code = extract_code(model(prompt))
code = extract_code(model(prompt, stream=False)) # type: ignore
log_progress(
{
"type": "log",
@@ -211,7 +211,7 @@ def pick_plan(
"code": DefaultImports.prepend_imports(code),
}
)
code = extract_code(model(prompt))
code = extract_code(model(prompt, stream=False)) # type: ignore
tool_output = code_interpreter.exec_isolation(
DefaultImports.prepend_imports(code)
)
@@ -251,7 +251,7 @@ def pick_plan(
tool_output=tool_output_str[:20_000],
)
chat[-1]["content"] = prompt
best_plan = extract_json(model(chat))
best_plan = extract_json(model(chat, stream=False)) # type: ignore

if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{best_plan}")
@@ -286,7 +286,7 @@ def write_code(
feedback=feedback,
)
chat[-1]["content"] = prompt
return extract_code(coder(chat))
return extract_code(coder(chat, stream=False)) # type: ignore


def write_test(
Expand All @@ -310,7 +310,7 @@ def write_test(
media=media,
)
chat[-1]["content"] = prompt
return extract_code(tester(chat))
return extract_code(tester(chat, stream=False)) # type: ignore


def write_and_test_code(
Expand Down Expand Up @@ -439,13 +439,14 @@ def debug_code(
while not success and count < 3:
try:
fixed_code_and_test = extract_json(
debugger(
debugger( # type: ignore
FIX_BUG.format(
code=code,
tests=test,
result="\n".join(result.text().splitlines()[-50:]),
feedback=format_memory(working_memory + new_working_memory),
)
),
stream=False,
)
)
success = True
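
Taken together, every model call inside vision_agent_coder.py is now pinned with stream=False for the same reason as in vision_agent.py above: extract_json and extract_code expect a complete string, and the # type: ignore comments appear to silence the widened union return type rather than narrow it at each call site.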