From 755feb8914e56e245ebafb9d3a49b74833082c2d Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Sat, 3 Aug 2024 21:28:55 -0400 Subject: [PATCH] adding Types so I can limit the tools and the confusion that comes with that --- Modules/LlmDriver/app/Functions/Chat.php | 4 ++ .../app/Functions/CreateDocument.php | 5 +++ .../app/Functions/FunctionContract.php | 8 ++++ .../app/Functions/GatherInfoTool.php | 4 ++ .../app/Functions/GetWebSiteFromUrlTool.php | 7 +++ .../LlmDriver/app/Functions/ReportingTool.php | 4 ++ .../app/Functions/RetrieveRelated.php | 4 ++ .../LlmDriver/app/Functions/SearchTheWeb.php | 7 +++ .../app/Functions/StandardsChecker.php | 4 ++ .../app/Functions/SummarizeCollection.php | 4 ++ Modules/LlmDriver/app/Functions/ToolTypes.php | 16 +++++++ Modules/LlmDriver/app/LlmDriverClient.php | 44 ++++++++++++++----- .../tests/Feature/LlmDriverClientTest.php | 17 ++++++- .../Orchestration/OrchestrateVersionTwo.php | 2 + 14 files changed, 118 insertions(+), 12 deletions(-) create mode 100644 Modules/LlmDriver/app/Functions/ToolTypes.php diff --git a/Modules/LlmDriver/app/Functions/Chat.php b/Modules/LlmDriver/app/Functions/Chat.php index 363e5e11..a7698d74 100644 --- a/Modules/LlmDriver/app/Functions/Chat.php +++ b/Modules/LlmDriver/app/Functions/Chat.php @@ -13,6 +13,10 @@ class Chat extends FunctionContract { use ChatHelperTrait, ToolsHelper; + public array $toolTypes = [ + ToolTypes::Chat, + ]; + protected string $name = 'chat_only'; protected string $description = 'User just wants to continue the chat no need to look in the collection for more documents'; diff --git a/Modules/LlmDriver/app/Functions/CreateDocument.php b/Modules/LlmDriver/app/Functions/CreateDocument.php index 7eb3a129..831a9692 100644 --- a/Modules/LlmDriver/app/Functions/CreateDocument.php +++ b/Modules/LlmDriver/app/Functions/CreateDocument.php @@ -13,6 +13,11 @@ class CreateDocument extends FunctionContract { use ChatHelperTrait, ToolsHelper; + public array $toolTypes = [ + ToolTypes::Source, + ToolTypes::Chat, + ]; + protected string $name = 'create_document'; protected string $description = 'Create a document in the collection of this local system'; diff --git a/Modules/LlmDriver/app/Functions/FunctionContract.php b/Modules/LlmDriver/app/Functions/FunctionContract.php index c7bf4eec..b2f0a302 100644 --- a/Modules/LlmDriver/app/Functions/FunctionContract.php +++ b/Modules/LlmDriver/app/Functions/FunctionContract.php @@ -10,6 +10,14 @@ abstract class FunctionContract { protected string $name; + public array $toolTypes = [ + ToolTypes::Chat, + ToolTypes::ChatCompletion, + ToolTypes::Source, + ToolTypes::Output, + ToolTypes::NoFunction, + ]; + protected string $description; protected string $type = 'object'; diff --git a/Modules/LlmDriver/app/Functions/GatherInfoTool.php b/Modules/LlmDriver/app/Functions/GatherInfoTool.php index 69d5997d..e3a1ce9e 100644 --- a/Modules/LlmDriver/app/Functions/GatherInfoTool.php +++ b/Modules/LlmDriver/app/Functions/GatherInfoTool.php @@ -20,6 +20,10 @@ class GatherInfoTool extends FunctionContract { use ToolsHelper; + public array $toolTypes = [ + ToolTypes::ChatCompletion, + ]; + protected string $name = 'gather_info_tool'; protected string $description = 'This will look at all documents using your prompt then return the results after once more using your prompt'; diff --git a/Modules/LlmDriver/app/Functions/GetWebSiteFromUrlTool.php b/Modules/LlmDriver/app/Functions/GetWebSiteFromUrlTool.php index 7599e7ed..cce2b0ee 100644 --- a/Modules/LlmDriver/app/Functions/GetWebSiteFromUrlTool.php +++ b/Modules/LlmDriver/app/Functions/GetWebSiteFromUrlTool.php @@ -12,6 +12,13 @@ class GetWebSiteFromUrlTool extends FunctionContract { use ToolsHelper; + public array $toolTypes = [ + ToolTypes::ChatCompletion, + ToolTypes::Source, + ToolTypes::Output, + ToolTypes::Chat, + ]; + protected string $name = 'get_web_site_from_url'; protected string $description = 'If you add urls to a prompt and ask the llm to get the web site using the url(s) you give it'; diff --git a/Modules/LlmDriver/app/Functions/ReportingTool.php b/Modules/LlmDriver/app/Functions/ReportingTool.php index 6d3585f0..f2593974 100644 --- a/Modules/LlmDriver/app/Functions/ReportingTool.php +++ b/Modules/LlmDriver/app/Functions/ReportingTool.php @@ -21,6 +21,10 @@ class ReportingTool extends FunctionContract { use ToolsHelper; + public array $toolTypes = [ + ToolTypes::ChatCompletion, + ]; + protected string $name = 'reporting_tool'; protected string $description = 'Uses Reference collection to generate a report'; diff --git a/Modules/LlmDriver/app/Functions/RetrieveRelated.php b/Modules/LlmDriver/app/Functions/RetrieveRelated.php index dc662381..ba833ff3 100644 --- a/Modules/LlmDriver/app/Functions/RetrieveRelated.php +++ b/Modules/LlmDriver/app/Functions/RetrieveRelated.php @@ -18,6 +18,10 @@ class RetrieveRelated extends FunctionContract { use CreateReferencesTrait; + public array $toolTypes = [ + ToolTypes::ChatCompletion, + ]; + protected string $name = 'retrieve_related'; protected string $description = 'Used to embed users prompt, search local database and return summarized results. diff --git a/Modules/LlmDriver/app/Functions/SearchTheWeb.php b/Modules/LlmDriver/app/Functions/SearchTheWeb.php index 6684f55a..d069a1bc 100644 --- a/Modules/LlmDriver/app/Functions/SearchTheWeb.php +++ b/Modules/LlmDriver/app/Functions/SearchTheWeb.php @@ -16,6 +16,13 @@ class SearchTheWeb extends FunctionContract { use ChatHelperTrait, ToolsHelper; + public array $toolTypes = [ + ToolTypes::ChatCompletion, + ToolTypes::Source, + ToolTypes::Output, + ToolTypes::Chat, + ]; + protected string $name = 'search_the_web'; protected string $description = 'Search the web for a topic'; diff --git a/Modules/LlmDriver/app/Functions/StandardsChecker.php b/Modules/LlmDriver/app/Functions/StandardsChecker.php index 0b269616..6a2119e9 100644 --- a/Modules/LlmDriver/app/Functions/StandardsChecker.php +++ b/Modules/LlmDriver/app/Functions/StandardsChecker.php @@ -11,6 +11,10 @@ class StandardsChecker extends FunctionContract { + public array $toolTypes = [ + ToolTypes::ChatCompletion, + ]; + protected string $name = 'standards_checker'; protected string $description = 'Checks the prompt data follows the standards of the documents in the collection'; diff --git a/Modules/LlmDriver/app/Functions/SummarizeCollection.php b/Modules/LlmDriver/app/Functions/SummarizeCollection.php index 22e62247..ae73b2c1 100644 --- a/Modules/LlmDriver/app/Functions/SummarizeCollection.php +++ b/Modules/LlmDriver/app/Functions/SummarizeCollection.php @@ -17,6 +17,10 @@ class SummarizeCollection extends FunctionContract protected string $response = ''; + public array $toolTypes = [ + ToolTypes::ChatCompletion, + ]; + public function handle( Message $message): FunctionResponse { diff --git a/Modules/LlmDriver/app/Functions/ToolTypes.php b/Modules/LlmDriver/app/Functions/ToolTypes.php new file mode 100644 index 00000000..9e7d8870 --- /dev/null +++ b/Modules/LlmDriver/app/Functions/ToolTypes.php @@ -0,0 +1,16 @@ +toolType = $toolType; + + return $this; + } + public function driver($name = null) { $name = $name ?: $this->getDefaultDriver(); @@ -61,17 +72,28 @@ protected function getDefaultDriver() public function getFunctions(): array { - return [ - (new SummarizeCollection())->getFunction(), - (new RetrieveRelated())->getFunction(), - (new StandardsChecker())->getFunction(), - (new ReportingTool())->getFunction(), - (new GatherInfoTool())->getFunction(), - (new GetWebSiteFromUrlTool())->getFunction(), - (new SearchTheWeb())->getFunction(), - (new CreateDocument())->getFunction(), - (new Chat())->getFunction(), - ]; + $functions = collect( + [ + new SummarizeCollection(), + new RetrieveRelated(), + new StandardsChecker(), + new ReportingTool(), + new GatherInfoTool(), + new GetWebSiteFromUrlTool(), + new SearchTheWeb(), + new CreateDocument(), + new Chat(), + ] + ); + + if (isset($this->toolType)) { + $functions = $functions->filter(function (FunctionContract $function) { + return in_array($this->toolType, $function->toolTypes); + }); + } + + return $functions->toArray(); + } public function getFunctionsForUi(): array diff --git a/Modules/LlmDriver/tests/Feature/LlmDriverClientTest.php b/Modules/LlmDriver/tests/Feature/LlmDriverClientTest.php index ef4ad2ca..197a80f0 100644 --- a/Modules/LlmDriver/tests/Feature/LlmDriverClientTest.php +++ b/Modules/LlmDriver/tests/Feature/LlmDriverClientTest.php @@ -2,6 +2,7 @@ namespace Tests\Feature; +use LlmLaraHub\LlmDriver\Functions\ToolTypes; use LlmLaraHub\LlmDriver\LlmDriverFacade; use LlmLaraHub\LlmDriver\MockClient; use LlmLaraHub\LlmDriver\OpenAiClient; @@ -28,6 +29,20 @@ public function test_driver_openai(): void public function test_get_functions() { - $this->assertNotEmpty(LlmDriverFacade::getFunctions()); + $functions = LlmDriverFacade::getFunctions(); + + $this->assertCount(9, $functions); + + $function = LlmDriverFacade::setToolType( + ToolTypes::ChatCompletion + )->getFunctions(); + + $this->assertCount(7, $function); + + $function = LlmDriverFacade::setToolType( + ToolTypes::Chat + )->getFunctions(); + + $this->assertCount(4, $function); } } diff --git a/app/Domains/Orchestration/OrchestrateVersionTwo.php b/app/Domains/Orchestration/OrchestrateVersionTwo.php index 4dfe390c..2a0a876c 100644 --- a/app/Domains/Orchestration/OrchestrateVersionTwo.php +++ b/app/Domains/Orchestration/OrchestrateVersionTwo.php @@ -12,6 +12,7 @@ use Illuminate\Bus\Batch; use Illuminate\Support\Facades\Bus; use Illuminate\Support\Facades\Log; +use LlmLaraHub\LlmDriver\Functions\ToolTypes; use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait; use LlmLaraHub\LlmDriver\LlmDriverFacade; use LlmLaraHub\LlmDriver\ToolsHelper; @@ -102,6 +103,7 @@ public function handle( put_fixture('orchestrate_messages_fist_send.json', $messages); $response = LlmDriverFacade::driver($message->getDriver()) + ->setToolTypes(ToolTypes::ChatCompletion) ->chat($messages); if (! empty($response->tool_calls)) {