Skip to content

Commit

Permalink
adding Types so I can limit the tools and the confusion that comes wi…
Browse files Browse the repository at this point in the history
…th that
  • Loading branch information
alnutile committed Aug 4, 2024
1 parent 24d5461 commit 755feb8
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 12 deletions.
4 changes: 4 additions & 0 deletions Modules/LlmDriver/app/Functions/Chat.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ class Chat extends FunctionContract
{
use ChatHelperTrait, ToolsHelper;

public array $toolTypes = [
ToolTypes::Chat,
];

protected string $name = 'chat_only';

protected string $description = 'User just wants to continue the chat no need to look in the collection for more documents';
Expand Down
5 changes: 5 additions & 0 deletions Modules/LlmDriver/app/Functions/CreateDocument.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class CreateDocument extends FunctionContract
{
use ChatHelperTrait, ToolsHelper;

public array $toolTypes = [
ToolTypes::Source,
ToolTypes::Chat,
];

protected string $name = 'create_document';

protected string $description = 'Create a document in the collection of this local system';
Expand Down
8 changes: 8 additions & 0 deletions Modules/LlmDriver/app/Functions/FunctionContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ abstract class FunctionContract
{
protected string $name;

public array $toolTypes = [
ToolTypes::Chat,
ToolTypes::ChatCompletion,
ToolTypes::Source,
ToolTypes::Output,
ToolTypes::NoFunction,
];

protected string $description;

protected string $type = 'object';
Expand Down
4 changes: 4 additions & 0 deletions Modules/LlmDriver/app/Functions/GatherInfoTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class GatherInfoTool extends FunctionContract
{
use ToolsHelper;

public array $toolTypes = [
ToolTypes::ChatCompletion,
];

protected string $name = 'gather_info_tool';

protected string $description = 'This will look at all documents using your prompt then return the results after once more using your prompt';
Expand Down
7 changes: 7 additions & 0 deletions Modules/LlmDriver/app/Functions/GetWebSiteFromUrlTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ class GetWebSiteFromUrlTool extends FunctionContract
{
use ToolsHelper;

public array $toolTypes = [
ToolTypes::ChatCompletion,
ToolTypes::Source,
ToolTypes::Output,
ToolTypes::Chat,
];

protected string $name = 'get_web_site_from_url';

protected string $description = 'If you add urls to a prompt and ask the llm to get the web site using the url(s) you give it';
Expand Down
4 changes: 4 additions & 0 deletions Modules/LlmDriver/app/Functions/ReportingTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class ReportingTool extends FunctionContract
{
use ToolsHelper;

public array $toolTypes = [
ToolTypes::ChatCompletion,
];

protected string $name = 'reporting_tool';

protected string $description = 'Uses Reference collection to generate a report';
Expand Down
4 changes: 4 additions & 0 deletions Modules/LlmDriver/app/Functions/RetrieveRelated.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class RetrieveRelated extends FunctionContract
{
use CreateReferencesTrait;

public array $toolTypes = [
ToolTypes::ChatCompletion,
];

protected string $name = 'retrieve_related';

protected string $description = 'Used to embed users prompt, search local database and return summarized results.
Expand Down
7 changes: 7 additions & 0 deletions Modules/LlmDriver/app/Functions/SearchTheWeb.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ class SearchTheWeb extends FunctionContract
{
use ChatHelperTrait, ToolsHelper;

public array $toolTypes = [
ToolTypes::ChatCompletion,
ToolTypes::Source,
ToolTypes::Output,
ToolTypes::Chat,
];

protected string $name = 'search_the_web';

protected string $description = 'Search the web for a topic';
Expand Down
4 changes: 4 additions & 0 deletions Modules/LlmDriver/app/Functions/StandardsChecker.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

class StandardsChecker extends FunctionContract
{
public array $toolTypes = [
ToolTypes::ChatCompletion,
];

protected string $name = 'standards_checker';

protected string $description = 'Checks the prompt data follows the standards of the documents in the collection';
Expand Down
4 changes: 4 additions & 0 deletions Modules/LlmDriver/app/Functions/SummarizeCollection.php
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class SummarizeCollection extends FunctionContract

protected string $response = '';

public array $toolTypes = [
ToolTypes::ChatCompletion,
];

public function handle(
Message $message): FunctionResponse
{
Expand Down
16 changes: 16 additions & 0 deletions Modules/LlmDriver/app/Functions/ToolTypes.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<?php

namespace LlmLaraHub\LlmDriver\Functions;

use App\Helpers\EnumHelperTrait;

enum ToolTypes: string
{
use EnumHelperTrait;

case Chat = 'chat';
case ChatCompletion = 'chat_completion';
case Source = 'source';
case Output = 'output';
case NoFunction = 'no_function';
}
44 changes: 33 additions & 11 deletions Modules/LlmDriver/app/LlmDriverClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,29 @@

use LlmLaraHub\LlmDriver\Functions\Chat;
use LlmLaraHub\LlmDriver\Functions\CreateDocument;
use LlmLaraHub\LlmDriver\Functions\FunctionContract;
use LlmLaraHub\LlmDriver\Functions\GatherInfoTool;
use LlmLaraHub\LlmDriver\Functions\GetWebSiteFromUrlTool;
use LlmLaraHub\LlmDriver\Functions\ReportingTool;
use LlmLaraHub\LlmDriver\Functions\RetrieveRelated;
use LlmLaraHub\LlmDriver\Functions\SearchTheWeb;
use LlmLaraHub\LlmDriver\Functions\StandardsChecker;
use LlmLaraHub\LlmDriver\Functions\SummarizeCollection;
use LlmLaraHub\LlmDriver\Functions\ToolTypes;

class LlmDriverClient
{
protected $drivers = [];

protected ToolTypes $toolType;

public function setToolType(ToolTypes $toolType): self
{
$this->toolType = $toolType;

return $this;
}

public function driver($name = null)
{
$name = $name ?: $this->getDefaultDriver();
Expand Down Expand Up @@ -61,17 +72,28 @@ protected function getDefaultDriver()

public function getFunctions(): array
{
return [
(new SummarizeCollection())->getFunction(),
(new RetrieveRelated())->getFunction(),
(new StandardsChecker())->getFunction(),
(new ReportingTool())->getFunction(),
(new GatherInfoTool())->getFunction(),
(new GetWebSiteFromUrlTool())->getFunction(),
(new SearchTheWeb())->getFunction(),
(new CreateDocument())->getFunction(),
(new Chat())->getFunction(),
];
$functions = collect(
[
new SummarizeCollection(),
new RetrieveRelated(),
new StandardsChecker(),
new ReportingTool(),
new GatherInfoTool(),
new GetWebSiteFromUrlTool(),
new SearchTheWeb(),
new CreateDocument(),
new Chat(),
]
);

if (isset($this->toolType)) {
$functions = $functions->filter(function (FunctionContract $function) {
return in_array($this->toolType, $function->toolTypes);
});
}

return $functions->toArray();

}

public function getFunctionsForUi(): array
Expand Down
17 changes: 16 additions & 1 deletion Modules/LlmDriver/tests/Feature/LlmDriverClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace Tests\Feature;

use LlmLaraHub\LlmDriver\Functions\ToolTypes;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\MockClient;
use LlmLaraHub\LlmDriver\OpenAiClient;
Expand All @@ -28,6 +29,20 @@ public function test_driver_openai(): void

public function test_get_functions()
{
$this->assertNotEmpty(LlmDriverFacade::getFunctions());
$functions = LlmDriverFacade::getFunctions();

$this->assertCount(9, $functions);

$function = LlmDriverFacade::setToolType(
ToolTypes::ChatCompletion
)->getFunctions();

$this->assertCount(7, $function);

$function = LlmDriverFacade::setToolType(
ToolTypes::Chat
)->getFunctions();

$this->assertCount(4, $function);
}
}
2 changes: 2 additions & 0 deletions app/Domains/Orchestration/OrchestrateVersionTwo.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
use Illuminate\Bus\Batch;
use Illuminate\Support\Facades\Bus;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\Functions\ToolTypes;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\ToolsHelper;
Expand Down Expand Up @@ -102,6 +103,7 @@ public function handle(
put_fixture('orchestrate_messages_fist_send.json', $messages);

$response = LlmDriverFacade::driver($message->getDriver())
->setToolTypes(ToolTypes::ChatCompletion)
->chat($messages);

if (! empty($response->tool_calls)) {
Expand Down

0 comments on commit 755feb8

Please sign in to comment.