Skip to content

Commit 9a876a9

Browse files
committed
Add tests for StructuredAgent
1 parent de1714c commit 9a876a9

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

tests/agents/test_structured_agent.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from typing import AsyncIterator
2+
3+
from coagent.agents import ChatHistory, ChatMessage
4+
from coagent.agents.structured_agent import StructuredAgent
5+
from coagent.core import Context, GenericMessage, Message
6+
from pydantic import BaseModel
7+
import pytest
8+
9+
10+
class Input(Message):
11+
role: str = ""
12+
content: str = ""
13+
14+
15+
class Output(BaseModel):
16+
content: str = ""
17+
18+
19+
class MockAgent(StructuredAgent):
20+
async def _handle_history(
21+
self,
22+
msg: ChatHistory,
23+
response_format: dict | None = None,
24+
) -> AsyncIterator[ChatMessage]:
25+
if response_format and response_format["json_schema"]["name"] == "Output":
26+
out = Output(content="Hello!")
27+
yield ChatMessage(role="assistant", content=out.model_dump_json())
28+
else:
29+
yield ChatMessage(role="assistant", content="Hello!")
30+
31+
32+
class TestStructuredAgent:
33+
@pytest.mark.asyncio
34+
async def test_render_system(self):
35+
agent = StructuredAgent(
36+
input_type=Input, system="You are a helpful {{ role }}."
37+
)
38+
39+
system = await agent.render_system(Input(role="Translator"))
40+
assert system == "You are a helpful Translator."
41+
42+
@pytest.mark.asyncio
43+
async def test_render_messages(self):
44+
agent = StructuredAgent(
45+
input_type=Input,
46+
messages=[ChatMessage(role="user", content="{{ content }}")],
47+
)
48+
49+
messages = await agent.render_messages(Input(content="Hello"))
50+
assert messages == [ChatMessage(role="user", content="Hello")]
51+
52+
@pytest.mark.asyncio
53+
async def test_handle_input(self):
54+
agent = MockAgent(
55+
input_type=Input,
56+
)
57+
58+
# Success
59+
_input = GenericMessage.decode(Input().encode())
60+
async for msg in agent.handle(_input, Context()):
61+
assert msg.content == "Hello!"
62+
63+
# Error
64+
_input = GenericMessage.decode(ChatMessage(role="").encode())
65+
with pytest.raises(ValueError) as exc:
66+
async for _ in agent.handle(_input, Context()):
67+
pass
68+
assert "Invalid message type" in str(exc.value)
69+
70+
@pytest.mark.asyncio
71+
async def test_handle_output(self):
72+
agent = MockAgent(
73+
input_type=Input,
74+
output_type=Output,
75+
)
76+
77+
_input = GenericMessage.decode(Input().encode())
78+
async for msg in agent.handle(_input, Context()):
79+
assert msg.content == '{"content":"Hello!"}'

0 commit comments

Comments
 (0)