|
| 1 | +from typing import AsyncIterator |
| 2 | + |
| 3 | +from coagent.agents import ChatHistory, ChatMessage |
| 4 | +from coagent.agents.structured_agent import StructuredAgent |
| 5 | +from coagent.core import Context, GenericMessage, Message |
| 6 | +from pydantic import BaseModel |
| 7 | +import pytest |
| 8 | + |
| 9 | + |
| 10 | +class Input(Message): |
| 11 | + role: str = "" |
| 12 | + content: str = "" |
| 13 | + |
| 14 | + |
| 15 | +class Output(BaseModel): |
| 16 | + content: str = "" |
| 17 | + |
| 18 | + |
| 19 | +class MockAgent(StructuredAgent): |
| 20 | + async def _handle_history( |
| 21 | + self, |
| 22 | + msg: ChatHistory, |
| 23 | + response_format: dict | None = None, |
| 24 | + ) -> AsyncIterator[ChatMessage]: |
| 25 | + if response_format and response_format["json_schema"]["name"] == "Output": |
| 26 | + out = Output(content="Hello!") |
| 27 | + yield ChatMessage(role="assistant", content=out.model_dump_json()) |
| 28 | + else: |
| 29 | + yield ChatMessage(role="assistant", content="Hello!") |
| 30 | + |
| 31 | + |
| 32 | +class TestStructuredAgent: |
| 33 | + @pytest.mark.asyncio |
| 34 | + async def test_render_system(self): |
| 35 | + agent = StructuredAgent( |
| 36 | + input_type=Input, system="You are a helpful {{ role }}." |
| 37 | + ) |
| 38 | + |
| 39 | + system = await agent.render_system(Input(role="Translator")) |
| 40 | + assert system == "You are a helpful Translator." |
| 41 | + |
| 42 | + @pytest.mark.asyncio |
| 43 | + async def test_render_messages(self): |
| 44 | + agent = StructuredAgent( |
| 45 | + input_type=Input, |
| 46 | + messages=[ChatMessage(role="user", content="{{ content }}")], |
| 47 | + ) |
| 48 | + |
| 49 | + messages = await agent.render_messages(Input(content="Hello")) |
| 50 | + assert messages == [ChatMessage(role="user", content="Hello")] |
| 51 | + |
| 52 | + @pytest.mark.asyncio |
| 53 | + async def test_handle_input(self): |
| 54 | + agent = MockAgent( |
| 55 | + input_type=Input, |
| 56 | + ) |
| 57 | + |
| 58 | + # Success |
| 59 | + _input = GenericMessage.decode(Input().encode()) |
| 60 | + async for msg in agent.handle(_input, Context()): |
| 61 | + assert msg.content == "Hello!" |
| 62 | + |
| 63 | + # Error |
| 64 | + _input = GenericMessage.decode(ChatMessage(role="").encode()) |
| 65 | + with pytest.raises(ValueError) as exc: |
| 66 | + async for _ in agent.handle(_input, Context()): |
| 67 | + pass |
| 68 | + assert "Invalid message type" in str(exc.value) |
| 69 | + |
| 70 | + @pytest.mark.asyncio |
| 71 | + async def test_handle_output(self): |
| 72 | + agent = MockAgent( |
| 73 | + input_type=Input, |
| 74 | + output_type=Output, |
| 75 | + ) |
| 76 | + |
| 77 | + _input = GenericMessage.decode(Input().encode()) |
| 78 | + async for msg in agent.handle(_input, Context()): |
| 79 | + assert msg.content == '{"content":"Hello!"}' |
0 commit comments