From fda90257d24ca2b354c65ee45aa02a65b111d163 Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Mon, 9 Sep 2024 16:43:45 -0400 Subject: [PATCH] ok ready to merge --- .../app/Functions/CreateTasksTool.php | 7 +- .../tests/Feature/CreateTasksToolTest.php | 19 ++-- app/Domains/Projects/Orchestrate.php | 10 +-- app/Http/Controllers/ProjectController.php | 1 + app/Models/Message.php | 5 ++ resources/js/Layouts/AppLayout.vue | 9 +- resources/js/Layouts/MainMenu.vue | 1 + resources/js/Pages/Projects/Create.vue | 2 +- tests/Feature/ProjectOrchestrateTest.php | 11 ++- tests/fixtures/create_tasks_args.json | 88 +++++++++++++++++++ 10 files changed, 127 insertions(+), 26 deletions(-) create mode 100644 tests/fixtures/create_tasks_args.json diff --git a/Modules/LlmDriver/app/Functions/CreateTasksTool.php b/Modules/LlmDriver/app/Functions/CreateTasksTool.php index 96fb5b76..0137aea8 100644 --- a/Modules/LlmDriver/app/Functions/CreateTasksTool.php +++ b/Modules/LlmDriver/app/Functions/CreateTasksTool.php @@ -5,6 +5,7 @@ use App\Models\Message; use App\Models\Project; use App\Models\Task; +use App\Models\User; use Facades\App\Domains\Sources\WebSearch\GetPage; use Illuminate\Support\Facades\Log; use LlmLaraHub\LlmDriver\Responses\FunctionResponse; @@ -27,11 +28,11 @@ class CreateTasksTool extends FunctionContract public function handle( - Message $message, - array $args = []): FunctionResponse + Message $message): FunctionResponse { Log::info('TaskTool called'); + $args = $message->args; foreach (data_get($args, 'tasks', []) as $taskArg) { $name = data_get($taskArg, 'name', null); $details = data_get($taskArg, 'details', null); @@ -49,7 +50,7 @@ public function handle( 'details' => $details, 'due_date' => $due_date, 'assistant' => $assistant, - 'user_id' => $user_id, + 'user_id' => ($user_id !== "" && User::whereId($user_id)->exists()) ? $user_id : null, ]); } diff --git a/Modules/LlmDriver/tests/Feature/CreateTasksToolTest.php b/Modules/LlmDriver/tests/Feature/CreateTasksToolTest.php index 3c36690c..23707cb8 100644 --- a/Modules/LlmDriver/tests/Feature/CreateTasksToolTest.php +++ b/Modules/LlmDriver/tests/Feature/CreateTasksToolTest.php @@ -30,27 +30,24 @@ public function test_generates_tasks(): void 'chatable_type' => Project::class, ]); - $message = Message::factory()->create([ - 'chat_id' => $chat->id, - ]); - - $data = get_fixture('claude_chat_response.json'); $data = data_get($data, 'tool_calls.1.arguments.tasks'); - $this->assertDatabaseCount('tasks', 0); - - - (new CreateTasksTool())->handle($message, [ - 'tasks' => $data, + $message = Message::factory()->create([ + 'chat_id' => $chat->id, + 'args' => [ + 'tasks' => $data, + ] ]); + $this->assertDatabaseCount('tasks', 0); + + (new CreateTasksTool())->handle($message); $this->assertDatabaseCount('tasks', 5); $this->assertCount(5, $project->refresh()->tasks); } - } diff --git a/app/Domains/Projects/Orchestrate.php b/app/Domains/Projects/Orchestrate.php index 409c686f..f85d7604 100644 --- a/app/Domains/Projects/Orchestrate.php +++ b/app/Domains/Projects/Orchestrate.php @@ -37,16 +37,16 @@ public function handle(Chat $chat, string $prompt, string $systemPrompt = ''): v 'tool_count' => count($response->tool_calls), ]); - $tool = app()->make($tool_call->name); - $tool->handle($chat, $tool_call->arguments); - - $chat->addInputWithTools( - message: sprintf('Tool %s complete', $tool_call->name), + $message = $chat->addInputWithTools( + message: sprintf('Tool %s', $tool_call->name), tool_id: $tool_call->id, tool_name: $tool_call->name, tool_args: $tool_call->arguments, ); + $tool = app()->make($tool_call->name); + $tool->handle($message, $tool_call->arguments); + $count++; } diff --git a/app/Http/Controllers/ProjectController.php b/app/Http/Controllers/ProjectController.php index 54ad5cb2..dcea362b 100644 --- a/app/Http/Controllers/ProjectController.php +++ b/app/Http/Controllers/ProjectController.php @@ -94,6 +94,7 @@ public function showWithChat(Project $project, Chat$chat) { 'chat' => new ChatResource($chat), 'messages' => MessageResource::collection($chat->messages() ->notSystem() + ->notTool() ->latest() ->paginate(3)), ]); diff --git a/app/Models/Message.php b/app/Models/Message.php index 2e4ca103..a7647e0d 100644 --- a/app/Models/Message.php +++ b/app/Models/Message.php @@ -71,6 +71,11 @@ public function scopeNotSystem(Builder $query) } + public function scopeNotTool(Builder $query) + { + return $query->where('role', '!=', RoleEnum::Tool->value); + } + /** * Return a compressed message */ diff --git a/resources/js/Layouts/AppLayout.vue b/resources/js/Layouts/AppLayout.vue index 84714f72..1e86c452 100644 --- a/resources/js/Layouts/AppLayout.vue +++ b/resources/js/Layouts/AppLayout.vue @@ -41,12 +41,13 @@ const theme = ref('dark')