Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved tool calling #7

Merged
merged 6 commits into from
Mar 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/main/kotlin/dev/gabrielolv/kaia/core/HandoffManager.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package dev.gabrielolv.kaia.core

import dev.gabrielolv.kaia.llm.LLMMessage
import dev.gabrielolv.kaia.utils.nextThreadId
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onEach
import java.util.concurrent.ConcurrentHashMap

/**
Expand Down Expand Up @@ -72,23 +75,22 @@ class HandoffManager(val orchestrator: Orchestrator, private val handoffAgentId:
suspend fun sendMessage(
conversationId: String,
message: Message
): Message? {
): Flow<LLMMessage>? {
assert(message.content.isNotBlank()) { "Message cannot be blank" }
assert(message.sender.isNotBlank()) { "Message sender cannot be empty" }

val conversation = conversations[conversationId]
?: return null

val response = orchestrator.processWithAgent(handoffAgentId, message)
conversation.messages.add(response)

return response
return orchestrator.processWithAgent(handoffAgentId, message).onEach { response ->
conversation.messages.add(response)
}
}

/**
* Get the conversation history
*/
fun getHistory(conversationId: String): List<Message>? {
fun getHistory(conversationId: String): List<LLMMessage>? {
return conversations[conversationId]?.messages?.toList()
}

Expand All @@ -106,7 +108,7 @@ class HandoffManager(val orchestrator: Orchestrator, private val handoffAgentId:
data class Conversation(
val id: String,
var currentAgentId: String,
val messages: MutableList<Message>,
val messages: MutableList<LLMMessage>,
val handoffs: MutableList<Handoff> = mutableListOf()
)

Expand Down
20 changes: 9 additions & 11 deletions src/main/kotlin/dev/gabrielolv/kaia/core/Orchestrator.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package dev.gabrielolv.kaia.core

import dev.gabrielolv.kaia.core.agents.Agent
import dev.gabrielolv.kaia.llm.LLMMessage
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

Expand Down Expand Up @@ -41,30 +41,28 @@ class Orchestrator(
/**
* Process a message through a specific agent
*/
suspend fun processWithAgent(agentId: String, message: Message): Message {
suspend fun processWithAgent(agentId: String, message: Message): Flow<LLMMessage> {
val agent = agents[agentId] ?: throw IllegalArgumentException("Agent $agentId not found")
return agent.process(message)
}

/**
* Send a message to multiple agents and collect their responses
*/
fun broadcast(message: Message, agentIds: List<String>): Flow<Message> = flow {
val responses = agentIds.map { agentId ->
fun broadcast(message: Message, agentIds: List<String>): Flow<LLMMessage> = flow {
agentIds.map { agentId ->
scope.async {
try {
val agent = agents[agentId] ?: throw IllegalArgumentException("Agent $agentId not found")
agent.process(message.copy(recipient = agentId))
agent.process(message.copy(recipient = agentId)).collect { emit(it) }
} catch (e: Exception) {
Message(
sender = "system",
recipient = "orchestrator",
content = "Error processing message by agent $agentId: ${e.message}"
emit(
LLMMessage.SystemMessage(
content = "Error processing message by agent $agentId: ${e.message}"
)
)
}
}
}

responses.awaitAll().forEach { emit(it) }
}
}
22 changes: 15 additions & 7 deletions src/main/kotlin/dev/gabrielolv/kaia/core/agents/Agent.kt
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
package dev.gabrielolv.kaia.core.agents

import dev.gabrielolv.kaia.core.Message

import dev.gabrielolv.kaia.llm.LLMMessage
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

/**
* The core Agent interface that defines the basic functionality of an agent.
* Updated to support Flow-based message processing.
*/
interface Agent {
val id: String
val name: String
val description: String

/**
* Process a message and generate a response
* Process a message and return a flow of messages
* This allows for streaming responses and intermediate steps like tool calls
*/
suspend fun process(message: Message): Message
fun process(message: Message): Flow<LLMMessage>

companion object {
/**
Expand All @@ -32,8 +36,11 @@ class AgentBuilder {
var id: String = ""
var name: String = ""
var description: String = ""
var processor: suspend (Message) -> Message = { message ->
Message(content = "Default response to: ${message.content}")

var processor: (Message) -> Flow<LLMMessage> = { message ->
flow {
emit(LLMMessage.SystemMessage(content = "Default response to: ${message.content}"))
}
}

fun build(): Agent = BaseAgent(
Expand All @@ -48,9 +55,10 @@ private class BaseAgent(
override val id: String,
override val name: String,
override val description: String,
private val processor: suspend (Message) -> Message
private val processor: (Message) -> Flow<LLMMessage>
) : Agent {
override suspend fun process(message: Message): Message {

override fun process(message: Message): Flow<LLMMessage> {
return processor(message)
}
}
120 changes: 59 additions & 61 deletions src/main/kotlin/dev/gabrielolv/kaia/core/agents/HandoffAgent.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package dev.gabrielolv.kaia.core.agents

import dev.gabrielolv.kaia.core.HandoffManager
import dev.gabrielolv.kaia.core.Message
import dev.gabrielolv.kaia.llm.LLMMessage
import dev.gabrielolv.kaia.llm.LLMOptions
import dev.gabrielolv.kaia.llm.LLMProvider
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.boolean
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive

/**
* Agent that can automatically decide when to hand off to other agents
* Updated to support Flow-based processing
*/
fun Agent.Companion.withHandoff(
handoffManager: HandoffManager,
Expand Down Expand Up @@ -52,76 +55,71 @@ fun Agent.Companion.withHandoff(
Base your decision on the expertise required to address the user's query.
"""

// Extend the processor with smart handoff capabilities
builder.processor = processor@{ message ->

// First, evaluate if we should hand off
val options = LLMOptions(
systemPrompt = handoffEvalPrompt,
temperature = 0.1 // Low temperature for more deterministic decisions
)

val evaluationResponse = provider.generate(message.content, options)

try {
// Parse the evaluation response
val decision = Json.parseToJsonElement(evaluationResponse.content).jsonObject
val shouldHandoff = decision["handoff"]?.jsonPrimitive?.boolean ?: false

if (shouldHandoff) {
val targetAgentId = decision["targetAgentId"]?.jsonPrimitive?.content ?: ""
val reason = decision["reason"]?.jsonPrimitive?.content ?: "No reason provided"
// Flow-based processor
builder.processor = { message ->
flow {
// First, evaluate if we should hand off
val evaluationOptions = LLMOptions(
systemPrompt = handoffEvalPrompt,
temperature = 0.1 // Low temperature for more deterministic decisions
)

// Create handoff
val success = handoffManager.handoff(
conversationId = conversationId,
targetAgentId = targetAgentId,
reason = reason
)
// Collect the evaluation response
val evaluationMessages = provider.generate(message.content, evaluationOptions).toList()
val evaluationContent = evaluationMessages
.filterIsInstance<LLMMessage.AssistantMessage>()
.firstOrNull()?.content ?: ""

try {
// Parse the evaluation response
val decision = Json.parseToJsonElement(evaluationContent).jsonObject
val shouldHandoff = decision["handoff"]?.jsonPrimitive?.boolean ?: false

if (shouldHandoff) {
val targetAgentId = decision["targetAgentId"]?.jsonPrimitive?.content ?: ""
val reason = decision["reason"]?.jsonPrimitive?.content ?: "No reason provided"

// Create handoff
val success = handoffManager.handoff(
conversationId = conversationId,
targetAgentId = targetAgentId,
reason = reason
)

if (success) {
// Process with new agent
val targetAgent = handoffManager.getConversation(conversationId)?.currentAgentId?.let {
handoffManager.orchestrator.getAgent(it)
if (success) {
// Process with new agent
val targetAgent = handoffManager.getConversation(conversationId)?.currentAgentId?.let {
handoffManager.orchestrator.getAgent(it)
}

if (targetAgent != null) {
// Forward the message flow from the target agent
targetAgent.process(message).collect { emit(it) }
} else {
emit(LLMMessage.SystemMessage(content = "Handoff to agent $targetAgentId successful, but could not process message."))
}
} else {
// Handoff failed, process normally
val options = LLMOptions(systemPrompt = systemPrompt)
val handoffFailedPrefix = "I tried to hand off your request to a more specialized agent, but couldn't. I'll do my best to help.\n\n"

// Collect messages from the provider
provider.generate(message.content, options).collect(::emit)
}

return@processor targetAgent?.process(message) ?: Message(
sender = builder.id,
recipient = message.sender,
content = "Handoff to agent $targetAgentId successful, but could not process message."
)
} else {
// Handoff failed, process normally
// No handoff needed, process normally
val options = LLMOptions(systemPrompt = systemPrompt)
val response = provider.generate(message.content, options)

Message(
sender = builder.id,
recipient = message.sender,
content = "I tried to hand off your request to a more specialized agent, but couldn't. I'll do my best to help.\n\n${response.content}"
)
// Collect messages from the provider
provider.generate(message.content, options).collect(::emit)
}
} else {
// No handoff needed, process normally
} catch (e: Exception) {
// Parsing failed, process normally
val options = LLMOptions(systemPrompt = systemPrompt)
val response = provider.generate(message.content, options)

Message(
sender = builder.id,
recipient = message.sender,
content = response.content
)
// Collect messages from the provider
provider.generate(message.content, options).collect(::emit)
}
} catch (e: Exception) {
// Parsing failed, process normally
val options = LLMOptions(systemPrompt = systemPrompt)
val response = provider.generate(message.content, options)

Message(
sender = builder.id,
recipient = message.sender,
content = response.content
)
}
}

Expand Down
14 changes: 7 additions & 7 deletions src/main/kotlin/dev/gabrielolv/kaia/core/agents/LLMAgent.kt
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
package dev.gabrielolv.kaia.core.agents

import dev.gabrielolv.kaia.core.Message
import dev.gabrielolv.kaia.llm.LLMOptions
import dev.gabrielolv.kaia.llm.LLMProvider

/**
* Creates an agent that uses an LLM provider to generate responses
*/
fun Agent.Companion.llm(
provider: LLMProvider,
systemPrompt: String? = null,
block: AgentBuilder.() -> Unit
): Agent {
val builder = AgentBuilder().apply(block)

// Set up the flow-based processor
builder.processor = processor@{ message ->

val options = LLMOptions(
systemPrompt = systemPrompt,
temperature = 0.7
)

val llmResponse = provider.generate(message.content, options)
return@processor provider.generate(message.content, options)

Message(
sender = builder.id.takeIf { it.isNotEmpty() } ?: "llm-agent",
recipient = message.sender,
content = llmResponse.content
)
}

return builder.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import kotlinx.serialization.json.jsonObject
class ToolManager(private val json: Json = Json) {
private val tools = mutableMapOf<String, Tool>()

var errorHandler: suspend (Tool, ToolResult) -> Unit = { tool, result -> }

/**
* Register a tool
*/
Expand Down Expand Up @@ -37,7 +39,11 @@ class ToolManager(private val json: Json = Json) {
)

return try {
tool.execute(parameters)
val result = tool.execute(parameters)
if (!result.success) {
errorHandler(tool, result)
}
result
} catch (e: Exception) {
ToolResult(
success = false,
Expand Down
4 changes: 3 additions & 1 deletion src/main/kotlin/dev/gabrielolv/kaia/core/tools/ToolResult.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package dev.gabrielolv.kaia.core.tools

import dev.gabrielolv.kaia.core.tools.typed.validation.ValidationError
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

@Serializable
data class ToolResult(
val success: Boolean,
val result: String,
val metadata: JsonObject = JsonObject(emptyMap())
val metadata: JsonObject = JsonObject(emptyMap()),
val validationErrors: List<ValidationError> = emptyList(),
)
Loading