From 351f6efaf1ff83ce3f165486b58c093682db7769 Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Sat, 30 Mar 2024 19:38:51 -0400 Subject: [PATCH 01/10] Start the function work --- app/Http/Controllers/ChatController.php | 11 +++++++ app/LlmDriver/BaseClient.php | 5 ++++ app/LlmDriver/Functions/FunctionContract.php | 13 +++++---- app/LlmDriver/Functions/ParametersDto.php | 4 +-- .../{ParameterDto.php => PropertyDto.php} | 2 +- .../Functions/SearchAndSummarize.php | 5 ++-- app/LlmDriver/LlmDriverClient.php | 9 ++++++ app/LlmDriver/LlmServiceProvider.php | 28 ++++++++++++++++++ app/LlmDriver/OpenAiClient.php | 17 +++++++++++ app/LlmDriver/readme.md | 24 +++++++++++++++ app/Providers/AppServiceProvider.php | 5 ---- bootstrap/providers.php | 1 + tests/Feature/FunctionDtoTest.php | 17 ++++++----- tests/Feature/LlmDriverClientTest.php | 4 +++ tests/Feature/OpenAiClientTest.php | 18 ++++++++++-- tests/Feature/SearchAndSummarizeTest.php | 29 +++++++++++++++++++ 16 files changed, 164 insertions(+), 28 deletions(-) rename app/LlmDriver/Functions/{ParameterDto.php => PropertyDto.php} (85%) create mode 100644 app/LlmDriver/LlmServiceProvider.php create mode 100644 app/LlmDriver/readme.md create mode 100644 tests/Feature/SearchAndSummarizeTest.php diff --git a/app/Http/Controllers/ChatController.php b/app/Http/Controllers/ChatController.php index 5ea2bd99..76b58c0a 100644 --- a/app/Http/Controllers/ChatController.php +++ b/app/Http/Controllers/ChatController.php @@ -6,6 +6,7 @@ use App\Http\Resources\ChatResource; use App\Http\Resources\CollectionResource; use App\Http\Resources\MessageResource; +use App\LlmDriver\LlmDriverFacade; use App\Models\Chat; use App\Models\Collection; use Facades\App\Domains\Messages\SearchOrSummarizeChatRepo; @@ -44,6 +45,16 @@ public function chat(Chat $chat) 'input' => 'required|string', ]); + //get all the functions we have + + $response = LlmDriverFacade::driver( + $chat->chatable->getDriver() + )->chat($validated['input']); + //attach them the the initial request + //see if we get results or a function request + //if we get a function request, we run the function + //if we get results, we return the results + $response = SearchOrSummarizeChatRepo::search($chat, $validated['input']); ChatUpdatedEvent::dispatch($chat->chatable, $chat); diff --git a/app/LlmDriver/BaseClient.php b/app/LlmDriver/BaseClient.php index 3ca6a678..7013589b 100644 --- a/app/LlmDriver/BaseClient.php +++ b/app/LlmDriver/BaseClient.php @@ -63,4 +63,9 @@ protected function getConfig(string $driver): array { return config("llmdriver.drivers.$driver"); } + + + public function getFunctions() : array { + return []; + } } diff --git a/app/LlmDriver/Functions/FunctionContract.php b/app/LlmDriver/Functions/FunctionContract.php index d4e86390..d08ee757 100644 --- a/app/LlmDriver/Functions/FunctionContract.php +++ b/app/LlmDriver/Functions/FunctionContract.php @@ -8,6 +8,8 @@ abstract class FunctionContract protected string $dscription; + protected string $type = 'object'; + /** * @param array $data * @return array @@ -20,10 +22,12 @@ public function getFunction(): FunctionDto [ 'name' => $this->getName(), 'description' => $this->getDescription(), - 'parameters' => $this->getParameters(), + 'parameters' => [ + 'type' => $this->type, + 'properties' => $this->getProperties(), + ], ] ); - } protected function getName(): string @@ -34,10 +38,7 @@ protected function getName(): string /** * @return ParameterDto[] */ - protected function getParameters(): array - { - return []; - } + abstract protected function getProperties(): array; protected function getDescription(): string { diff --git a/app/LlmDriver/Functions/ParametersDto.php b/app/LlmDriver/Functions/ParametersDto.php index 3225a25f..57827861 100644 --- a/app/LlmDriver/Functions/ParametersDto.php +++ b/app/LlmDriver/Functions/ParametersDto.php @@ -5,12 +5,12 @@ class ParametersDto extends \Spatie\LaravelData\Data { /** - * @param ParameterDto[] $parameters + * @param PropertyDto[] $properties * @return void */ public function __construct( public string $type = 'object', - public array $parameters = [], + public array $properties = [], ) { } } diff --git a/app/LlmDriver/Functions/ParameterDto.php b/app/LlmDriver/Functions/PropertyDto.php similarity index 85% rename from app/LlmDriver/Functions/ParameterDto.php rename to app/LlmDriver/Functions/PropertyDto.php index c7b29d8c..2dde30c3 100644 --- a/app/LlmDriver/Functions/ParameterDto.php +++ b/app/LlmDriver/Functions/PropertyDto.php @@ -2,7 +2,7 @@ namespace App\LlmDriver\Functions; -class ParameterDto extends \Spatie\LaravelData\Data +class PropertyDto extends \Spatie\LaravelData\Data { public function __construct( public string $name, diff --git a/app/LlmDriver/Functions/SearchAndSummarize.php b/app/LlmDriver/Functions/SearchAndSummarize.php index 2fe5e129..4b795313 100644 --- a/app/LlmDriver/Functions/SearchAndSummarize.php +++ b/app/LlmDriver/Functions/SearchAndSummarize.php @@ -10,17 +10,16 @@ class SearchAndSummarize extends FunctionContract public function handle(FunctionCallDto $functionCallDto): array { - return []; } /** * @return ParameterDto[] */ - protected function getParameters(): array + protected function getProperties(): array { return [ - new ParameterDto( + new PropertyDto( name: 'prompt', description: 'The prompt to search for in the database.', type: 'string', diff --git a/app/LlmDriver/LlmDriverClient.php b/app/LlmDriver/LlmDriverClient.php index 228c3241..b24e9593 100644 --- a/app/LlmDriver/LlmDriverClient.php +++ b/app/LlmDriver/LlmDriverClient.php @@ -2,6 +2,8 @@ namespace App\LlmDriver; +use App\LlmDriver\Functions\SearchAndSummarize; + class LlmDriverClient { protected $drivers = []; @@ -46,4 +48,11 @@ protected function getDefaultDriver() { return 'mock'; } + + public function getFunctions() : array { + return [ + (new SearchAndSummarize())->getFunction(), + ]; + } + } diff --git a/app/LlmDriver/LlmServiceProvider.php b/app/LlmDriver/LlmServiceProvider.php new file mode 100644 index 00000000..e83b278b --- /dev/null +++ b/app/LlmDriver/LlmServiceProvider.php @@ -0,0 +1,28 @@ +app->bind('llm_driver', function () { + return new LlmDriverClient(); + }); + + } +} diff --git a/app/LlmDriver/OpenAiClient.php b/app/LlmDriver/OpenAiClient.php index f83c7579..863d76c5 100644 --- a/app/LlmDriver/OpenAiClient.php +++ b/app/LlmDriver/OpenAiClient.php @@ -16,6 +16,9 @@ class OpenAiClient extends BaseClient */ public function chat(array $messages): CompletionResponse { + $functions = $this->getFunctions(); + + $response = OpenAI::chat()->create([ 'model' => $this->getConfig('openai')['completion_model'], 'messages' => collect($messages)->map(function ($message) { @@ -69,4 +72,18 @@ public function completion(string $prompt, int $temperature = 0): CompletionResp return new CompletionResponse($results); } + + /** + * @NOTE + * Since this abstraction layer is based on OpenAi + * Not much needs to happen here + * but on the others I might need to do XML? + * @return array + */ + public function getFunctions() : array { + $functions = LlmDriverFacade::getFunctions(); + return collect($functions)->map(function ($function) { + return $function->toArray(); + })->toArray(); + } } diff --git a/app/LlmDriver/readme.md b/app/LlmDriver/readme.md new file mode 100644 index 00000000..5ebfc815 --- /dev/null +++ b/app/LlmDriver/readme.md @@ -0,0 +1,24 @@ +# LLmDriver + +## Todo move this to the Laravel Module library + +[https://github.com/nWidart/laravel-modules](https://github.com/nWidart/laravel-modules) + + +For now just taking notes of things I need to remember. + + +Load Provider. + +```php +isAdmin(); }); - $this->app->bind('llm_driver', function () { - return new LlmDriverClient(); - }); - } } diff --git a/bootstrap/providers.php b/bootstrap/providers.php index 4b3883c0..51a15e98 100644 --- a/bootstrap/providers.php +++ b/bootstrap/providers.php @@ -5,4 +5,5 @@ App\Providers\FortifyServiceProvider::class, App\Providers\HorizonServiceProvider::class, App\Providers\JetstreamServiceProvider::class, + App\LlmDriver\LlmServiceProvider::class, ]; diff --git a/tests/Feature/FunctionDtoTest.php b/tests/Feature/FunctionDtoTest.php index 71e6c394..4b8d8f52 100644 --- a/tests/Feature/FunctionDtoTest.php +++ b/tests/Feature/FunctionDtoTest.php @@ -3,8 +3,8 @@ namespace Tests\Feature; use App\LlmDriver\Functions\FunctionDto; -use App\LlmDriver\Functions\ParameterDto; use App\LlmDriver\Functions\ParametersDto; +use App\LlmDriver\Functions\PropertyDto; use Tests\TestCase; class FunctionDtoTest extends TestCase @@ -20,7 +20,7 @@ public function test_dto(): void 'description' => 'test', 'parameters' => [ 'type' => 'object', - 'parameters' => [ + 'properties' => [ [ 'name' => 'test', 'description' => 'test', @@ -45,9 +45,9 @@ public function test_dto(): void $this->assertNotNull($dto->name); $this->assertNotNull($dto->description); $this->assertInstanceOf(ParametersDto::class, $dto->parameters); - $this->assertCount(2, $dto->parameters->parameters); - $parameterOne = $dto->parameters->parameters[0]; - $this->assertInstanceOf(ParameterDto::class, $parameterOne); + $this->assertCount(2, $dto->parameters->properties); + $parameterOne = $dto->parameters->properties[0]; + $this->assertInstanceOf(PropertyDto::class, $parameterOne); $this->assertEquals('test', $parameterOne->name); $this->assertEquals('test', $parameterOne->description); $this->assertEquals('string', $parameterOne->type); @@ -55,14 +55,15 @@ public function test_dto(): void $this->assertEquals('', $parameterOne->default); $this->assertFalse($parameterOne->required); - $parameterTwo = $dto->parameters->parameters[1]; - $this->assertInstanceOf(ParameterDto::class, $parameterTwo); + $parameterTwo = $dto->parameters->properties[1]; + $this->assertInstanceOf(PropertyDto::class, $parameterTwo); $this->assertEquals('test2', $parameterTwo->name); $this->assertEquals('test2', $parameterTwo->description); $this->assertEquals('string', $parameterTwo->type); $this->assertEquals(['foo', 'bar'], $parameterTwo->enum); $this->assertEquals('bar', $parameterTwo->default); $this->assertTrue($parameterTwo->required); - } + + } diff --git a/tests/Feature/LlmDriverClientTest.php b/tests/Feature/LlmDriverClientTest.php index 4a8b7970..0035ec0e 100644 --- a/tests/Feature/LlmDriverClientTest.php +++ b/tests/Feature/LlmDriverClientTest.php @@ -25,4 +25,8 @@ public function test_driver_openai(): void $this->assertInstanceOf(OpenAiClient::class, $results); } + + public function test_get_functions() { + $this->assertNotEmpty(LlmDriverFacade::getFunctions()); + } } diff --git a/tests/Feature/OpenAiClientTest.php b/tests/Feature/OpenAiClientTest.php index 6d93a2d3..5eac79c0 100644 --- a/tests/Feature/OpenAiClientTest.php +++ b/tests/Feature/OpenAiClientTest.php @@ -12,9 +12,21 @@ class OpenAiClientTest extends TestCase { - /** - * A basic feature test example. - */ + + public function test_get_functions(): void + { + $openaiClient = new \App\LlmDriver\OpenAiClient(); + $response = $openaiClient->getFunctions(); + $this->assertNotEmpty($response); + $this->assertIsArray($response); + $first = $response[0]; + $this->assertArrayHasKey('name', $first); + $this->assertArrayHasKey('description', $first); + $this->assertArrayHasKey('parameters', $first); + + + } + public function test_openai_client(): void { OpenAI::fake([ diff --git a/tests/Feature/SearchAndSummarizeTest.php b/tests/Feature/SearchAndSummarizeTest.php new file mode 100644 index 00000000..312fd505 --- /dev/null +++ b/tests/Feature/SearchAndSummarizeTest.php @@ -0,0 +1,29 @@ +getFunction(); + + $parameters = $function->parameters; + + $this->assertInstanceOf(ParametersDto::class, $parameters); + $this->assertIsArray($parameters->properties); + $this->assertInstanceOf(PropertyDto::class, $parameters->properties[0]); + } +} From abfdfcbe9e9d81ff9df8d24e522886bc87a44ecd Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Mon, 1 Apr 2024 21:25:25 -0400 Subject: [PATCH 02/10] add function helper --- app/LlmDriver/BaseClient.php | 40 +++++++++++++++++- app/LlmDriver/Functions/FunctionContract.php | 20 ++++----- .../Functions/SearchAndSummarize.php | 6 +-- app/LlmDriver/LlmDriverClient.php | 4 +- app/LlmDriver/LlmServiceProvider.php | 1 - app/LlmDriver/OpenAiClient.php | 42 ++++++++++++++++--- tests/Feature/FunctionDtoTest.php | 2 - tests/Feature/LlmDriverClientTest.php | 3 +- tests/Feature/MockClientTest.php | 24 +++++++++-- tests/Feature/OpenAiClientTest.php | 8 ++-- tests/Feature/SearchAndSummarizeTest.php | 3 -- tests/fixtures/claude_messages_debug.json | 20 +-------- .../fixtures/openai_client_get_functions.json | 23 ++++++++++ .../openai_response_with_functions.json | 33 +++++++++++++++ 14 files changed, 174 insertions(+), 55 deletions(-) create mode 100644 tests/fixtures/openai_client_get_functions.json create mode 100644 tests/fixtures/openai_response_with_functions.json diff --git a/app/LlmDriver/BaseClient.php b/app/LlmDriver/BaseClient.php index 7013589b..1c6f1e55 100644 --- a/app/LlmDriver/BaseClient.php +++ b/app/LlmDriver/BaseClient.php @@ -26,6 +26,42 @@ public function embedData(string $data): EmbeddingsResponseDto ]); } + /** + * This is to get functions out of the llm + * if none are returned your system + * can error out or try another way. + * + * @param MessageInDto[] $messages + */ + public function functionPromptChat(array $messages, array $only = []): array + { + if (! app()->environment('testing')) { + sleep(2); + } + + Log::info('LlmDriver::MockClient::functionPromptChat', $messages); + + $data = get_fixture('openai_response_with_functions.json'); + + $functions = []; + + foreach (data_get($data, 'choices', []) as $choice) { + foreach (data_get($choice, 'message.toolCalls', []) as $tool) { + if (data_get($tool, 'type') === 'function') { + $name = data_get($tool, 'function.name', null); + if (! in_array($name, $only)) { + $functions[] = [ + 'name' => $name, + 'arguments' => json_decode(data_get($tool, 'function.arguments', []), true), + ]; + } + } + } + } + + return $functions; + } + /** * @param MessageInDto[] $messages */ @@ -64,8 +100,8 @@ protected function getConfig(string $driver): array return config("llmdriver.drivers.$driver"); } - - public function getFunctions() : array { + public function getFunctions(): array + { return []; } } diff --git a/app/LlmDriver/Functions/FunctionContract.php b/app/LlmDriver/Functions/FunctionContract.php index d08ee757..55033320 100644 --- a/app/LlmDriver/Functions/FunctionContract.php +++ b/app/LlmDriver/Functions/FunctionContract.php @@ -2,18 +2,16 @@ namespace App\LlmDriver\Functions; +use App\LlmDriver\Functions\PropertyDto; + abstract class FunctionContract { protected string $name; - protected string $dscription; + protected string $description; protected string $type = 'object'; - /** - * @param array $data - * @return array - */ abstract public function handle(FunctionCallDto $functionCallDto): array; public function getFunction(): FunctionDto @@ -35,13 +33,13 @@ protected function getName(): string return $this->name; } - /** - * @return ParameterDto[] - */ - abstract protected function getProperties(): array; - protected function getDescription(): string { - return $this->name; + return $this->description; } + + /** + * @return PropertyDto[] + */ + abstract protected function getProperties(): array; } diff --git a/app/LlmDriver/Functions/SearchAndSummarize.php b/app/LlmDriver/Functions/SearchAndSummarize.php index 4b795313..9c63c9e8 100644 --- a/app/LlmDriver/Functions/SearchAndSummarize.php +++ b/app/LlmDriver/Functions/SearchAndSummarize.php @@ -6,7 +6,7 @@ class SearchAndSummarize extends FunctionContract { protected string $name = 'search_and_summarize'; - protected string $dscription = 'Used to embed users prompt, search database and return summarized results.'; + protected string $description = 'Used to embed users prompt, search database and return summarized results.'; public function handle(FunctionCallDto $functionCallDto): array { @@ -14,14 +14,14 @@ public function handle(FunctionCallDto $functionCallDto): array } /** - * @return ParameterDto[] + * @return PropertyDto[] */ protected function getProperties(): array { return [ new PropertyDto( name: 'prompt', - description: 'The prompt to search for in the database.', + description: 'The prompt the user is using the search for.', type: 'string', required: true, ), diff --git a/app/LlmDriver/LlmDriverClient.php b/app/LlmDriver/LlmDriverClient.php index b24e9593..d432249f 100644 --- a/app/LlmDriver/LlmDriverClient.php +++ b/app/LlmDriver/LlmDriverClient.php @@ -49,10 +49,10 @@ protected function getDefaultDriver() return 'mock'; } - public function getFunctions() : array { + public function getFunctions(): array + { return [ (new SearchAndSummarize())->getFunction(), ]; } - } diff --git a/app/LlmDriver/LlmServiceProvider.php b/app/LlmDriver/LlmServiceProvider.php index e83b278b..6dbf9e22 100644 --- a/app/LlmDriver/LlmServiceProvider.php +++ b/app/LlmDriver/LlmServiceProvider.php @@ -2,7 +2,6 @@ namespace App\LlmDriver; -use App\LlmDriver\LlmDriverClient; use Illuminate\Support\ServiceProvider; class LlmServiceProvider extends ServiceProvider diff --git a/app/LlmDriver/OpenAiClient.php b/app/LlmDriver/OpenAiClient.php index 863d76c5..bc2d324c 100644 --- a/app/LlmDriver/OpenAiClient.php +++ b/app/LlmDriver/OpenAiClient.php @@ -18,12 +18,13 @@ public function chat(array $messages): CompletionResponse { $functions = $this->getFunctions(); - $response = OpenAI::chat()->create([ - 'model' => $this->getConfig('openai')['completion_model'], + 'model' => $this->getConfig('openai')['chat_model'], 'messages' => collect($messages)->map(function ($message) { return $message->toArray(); })->toArray(), + 'tool_choice' => 'auto', + 'tools' => $functions, ]); $results = null; @@ -78,12 +79,43 @@ public function completion(string $prompt, int $temperature = 0): CompletionResp * Since this abstraction layer is based on OpenAi * Not much needs to happen here * but on the others I might need to do XML? - * @return array */ - public function getFunctions() : array { + public function getFunctions(): array + { $functions = LlmDriverFacade::getFunctions(); + return collect($functions)->map(function ($function) { - return $function->toArray(); + $function = $function->toArray(); + $properties = []; + $required = []; + + foreach (data_get($function, 'parameters.properties', []) as $property) { + $name = data_get($property, 'name'); + + if (data_get($property, 'required', false)) { + $required[] = $name; + } + + $properties[$name] = [ + 'description' => data_get($property, 'description', null), + 'type' => data_get($property, 'type', 'string'), + 'enum' => data_get($property, 'enum', []), + 'default' => data_get($property, 'default', null), + ]; + } + + return [ + 'type' => 'function', + 'function' => [ + 'name' => data_get($function, 'name'), + 'description' => data_get($function, 'description'), + 'parameters' => [ + 'type' => 'object', + 'properties' => $properties, + ], + 'required' => $required, + ], + ]; })->toArray(); } } diff --git a/tests/Feature/FunctionDtoTest.php b/tests/Feature/FunctionDtoTest.php index 4b8d8f52..6a725735 100644 --- a/tests/Feature/FunctionDtoTest.php +++ b/tests/Feature/FunctionDtoTest.php @@ -64,6 +64,4 @@ public function test_dto(): void $this->assertEquals('bar', $parameterTwo->default); $this->assertTrue($parameterTwo->required); } - - } diff --git a/tests/Feature/LlmDriverClientTest.php b/tests/Feature/LlmDriverClientTest.php index 0035ec0e..8a8a8334 100644 --- a/tests/Feature/LlmDriverClientTest.php +++ b/tests/Feature/LlmDriverClientTest.php @@ -26,7 +26,8 @@ public function test_driver_openai(): void $this->assertInstanceOf(OpenAiClient::class, $results); } - public function test_get_functions() { + public function test_get_functions() + { $this->assertNotEmpty(LlmDriverFacade::getFunctions()); } } diff --git a/tests/Feature/MockClientTest.php b/tests/Feature/MockClientTest.php index aa00dd4a..69a56f42 100644 --- a/tests/Feature/MockClientTest.php +++ b/tests/Feature/MockClientTest.php @@ -10,9 +10,27 @@ class MockClientTest extends TestCase { - /** - * A basic feature test example. - */ + public function test_tools(): void + { + + $client = new MockClient(); + + $results = $client->functionPromptChat(['test']); + + $this->assertCount(1, $results); + + $this->assertEquals('search_and_summarize', $results[0]['name']); + } + + public function test_tool_with_limit(): void + { + $client = new MockClient(); + + $results = $client->functionPromptChat(['test'], ['search_and_summarize']); + + $this->assertCount(0, $results); + } + public function test_embeddings(): void { diff --git a/tests/Feature/OpenAiClientTest.php b/tests/Feature/OpenAiClientTest.php index 5eac79c0..c1689f9b 100644 --- a/tests/Feature/OpenAiClientTest.php +++ b/tests/Feature/OpenAiClientTest.php @@ -12,7 +12,6 @@ class OpenAiClientTest extends TestCase { - public function test_get_functions(): void { $openaiClient = new \App\LlmDriver\OpenAiClient(); @@ -20,10 +19,11 @@ public function test_get_functions(): void $this->assertNotEmpty($response); $this->assertIsArray($response); $first = $response[0]; - $this->assertArrayHasKey('name', $first); - $this->assertArrayHasKey('description', $first); - $this->assertArrayHasKey('parameters', $first); + $this->assertArrayHasKey('type', $first); + $this->assertArrayHasKey('function', $first); + $expected = get_fixture('openai_client_get_functions.json'); + $this->assertEquals($expected, $response); } diff --git a/tests/Feature/SearchAndSummarizeTest.php b/tests/Feature/SearchAndSummarizeTest.php index 312fd505..25eba72a 100644 --- a/tests/Feature/SearchAndSummarizeTest.php +++ b/tests/Feature/SearchAndSummarizeTest.php @@ -2,11 +2,8 @@ namespace Tests\Feature; -use App\LlmDriver\Functions\ParameterDto; use App\LlmDriver\Functions\ParametersDto; use App\LlmDriver\Functions\PropertyDto; -use Illuminate\Foundation\Testing\RefreshDatabase; -use Illuminate\Foundation\Testing\WithFaker; use Tests\TestCase; class SearchAndSummarizeTest extends TestCase diff --git a/tests/fixtures/claude_messages_debug.json b/tests/fixtures/claude_messages_debug.json index 1c773dcf..a041105e 100644 --- a/tests/fixtures/claude_messages_debug.json +++ b/tests/fixtures/claude_messages_debug.json @@ -1,26 +1,10 @@ [ { - "content": "test", + "content": "how does the document define 'Generative AI'", "role": "user" }, { - "content": "test 1", - "role": "assistant" - }, - { - "role": "user", - "content": "Continuation of search results" - }, - { - "content": "test 2", - "role": "assistant" - }, - { - "role": "user", - "content": "Continuation of search results" - }, - { - "content": "test 3", + "content": "The user uploaded documents that the user will ask questions about. Please keep your answers related to the documents.", "role": "assistant" } ] \ No newline at end of file diff --git a/tests/fixtures/openai_client_get_functions.json b/tests/fixtures/openai_client_get_functions.json new file mode 100644 index 00000000..ae98cf53 --- /dev/null +++ b/tests/fixtures/openai_client_get_functions.json @@ -0,0 +1,23 @@ +[ + { + "type": "function", + "function": { + "name": "search_and_summarize", + "description": "Used to embed users prompt, search database and return summarized results.", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "description": "The prompt the user is using the search for.", + "type": "string", + "enum": [], + "default": "" + } + } + }, + "required": [ + "prompt" + ] + } + } +] \ No newline at end of file diff --git a/tests/fixtures/openai_response_with_functions.json b/tests/fixtures/openai_response_with_functions.json new file mode 100644 index 00000000..b6bbdcf4 --- /dev/null +++ b/tests/fixtures/openai_response_with_functions.json @@ -0,0 +1,33 @@ +{ + "id": "chatcmpl-99MsQDhcyADcsehLfCQAwuyhhTxmB", + "object": "chat.completion", + "created": 1712019918, + "model": "gpt-4-0125-preview", + "systemFingerprint": "fp_f38f4d6482", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "toolCalls": [ + { + "id": "call_u3GOeiE4LaSJvOqV2uOqeXK2", + "type": "function", + "function": { + "name": "search_and_summarize", + "arguments": "{\"prompt\":\"define 'Generative AI'\"}" + } + } + ], + "functionCall": null + }, + "finishReason": "tool_calls" + } + ], + "usage": { + "promptTokens": 101, + "completionTokens": 22, + "totalTokens": 123 + } +} \ No newline at end of file From 188c4cdb162192976dbe4f6b6322efb9e4c65d1d Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Mon, 1 Apr 2024 21:30:22 -0400 Subject: [PATCH 03/10] add function checker --- app/Http/Controllers/ChatController.php | 11 ----------- app/LlmDriver/Functions/FunctionContract.php | 2 -- phpunit.xml | 1 + tests/Feature/OpenAiClientTest.php | 1 - tests/fixtures/claude_messages_debug.json | 20 ++++++++++++++++++-- 5 files changed, 19 insertions(+), 16 deletions(-) diff --git a/app/Http/Controllers/ChatController.php b/app/Http/Controllers/ChatController.php index 76b58c0a..5ea2bd99 100644 --- a/app/Http/Controllers/ChatController.php +++ b/app/Http/Controllers/ChatController.php @@ -6,7 +6,6 @@ use App\Http\Resources\ChatResource; use App\Http\Resources\CollectionResource; use App\Http\Resources\MessageResource; -use App\LlmDriver\LlmDriverFacade; use App\Models\Chat; use App\Models\Collection; use Facades\App\Domains\Messages\SearchOrSummarizeChatRepo; @@ -45,16 +44,6 @@ public function chat(Chat $chat) 'input' => 'required|string', ]); - //get all the functions we have - - $response = LlmDriverFacade::driver( - $chat->chatable->getDriver() - )->chat($validated['input']); - //attach them the the initial request - //see if we get results or a function request - //if we get a function request, we run the function - //if we get results, we return the results - $response = SearchOrSummarizeChatRepo::search($chat, $validated['input']); ChatUpdatedEvent::dispatch($chat->chatable, $chat); diff --git a/app/LlmDriver/Functions/FunctionContract.php b/app/LlmDriver/Functions/FunctionContract.php index 55033320..94aad42d 100644 --- a/app/LlmDriver/Functions/FunctionContract.php +++ b/app/LlmDriver/Functions/FunctionContract.php @@ -2,8 +2,6 @@ namespace App\LlmDriver\Functions; -use App\LlmDriver\Functions\PropertyDto; - abstract class FunctionContract { protected string $name; diff --git a/phpunit.xml b/phpunit.xml index 236e4e5a..da2f836c 100644 --- a/phpunit.xml +++ b/phpunit.xml @@ -19,6 +19,7 @@ + diff --git a/tests/Feature/OpenAiClientTest.php b/tests/Feature/OpenAiClientTest.php index c1689f9b..4c7fb946 100644 --- a/tests/Feature/OpenAiClientTest.php +++ b/tests/Feature/OpenAiClientTest.php @@ -24,7 +24,6 @@ public function test_get_functions(): void $expected = get_fixture('openai_client_get_functions.json'); $this->assertEquals($expected, $response); - } public function test_openai_client(): void diff --git a/tests/fixtures/claude_messages_debug.json b/tests/fixtures/claude_messages_debug.json index a041105e..1c773dcf 100644 --- a/tests/fixtures/claude_messages_debug.json +++ b/tests/fixtures/claude_messages_debug.json @@ -1,10 +1,26 @@ [ { - "content": "how does the document define 'Generative AI'", + "content": "test", "role": "user" }, { - "content": "The user uploaded documents that the user will ask questions about. Please keep your answers related to the documents.", + "content": "test 1", + "role": "assistant" + }, + { + "role": "user", + "content": "Continuation of search results" + }, + { + "content": "test 2", + "role": "assistant" + }, + { + "role": "user", + "content": "Continuation of search results" + }, + { + "content": "test 3", "role": "assistant" } ] \ No newline at end of file From cdc00680f00262ab91dbe1eb1e285a9a93b15400 Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Thu, 4 Apr 2024 19:49:43 -0400 Subject: [PATCH 04/10] ok the first stab at orchestrate --- app/Events/ChatUiUpdateEvent.php | 44 +++++++ app/Http/Controllers/ChatController.php | 15 ++- app/LlmDriver/BaseClient.php | 4 + app/LlmDriver/Functions/ArgumentCaster.php | 17 +-- app/LlmDriver/Functions/FunctionCallDto.php | 6 +- app/LlmDriver/Functions/FunctionContract.php | 17 ++- .../Functions/SearchAndSummarize.php | 19 ++- .../Functions/SummarizeCollection.php | 46 +++++++ app/LlmDriver/LlmDriverClient.php | 2 + app/LlmDriver/LlmServiceProvider.php | 10 ++ app/LlmDriver/Orchestrate.php | 118 ++++++++++++++++++ app/LlmDriver/Responses/FunctionResponse.php | 11 ++ resources/js/Pages/Collection/Chat.vue | 8 ++ tests/Feature/FunctionCallDtoTest.php | 25 ++++ .../Http/Controllers/ChatControllerTest.php | 26 ++++ tests/Feature/OrchestrateTest.php | 77 ++++++++++++ tests/Feature/SummarizeCollectionTest.php | 26 ++++ 17 files changed, 450 insertions(+), 21 deletions(-) create mode 100644 app/Events/ChatUiUpdateEvent.php create mode 100644 app/LlmDriver/Functions/SummarizeCollection.php create mode 100644 app/LlmDriver/Orchestrate.php create mode 100644 app/LlmDriver/Responses/FunctionResponse.php create mode 100644 tests/Feature/FunctionCallDtoTest.php create mode 100644 tests/Feature/OrchestrateTest.php create mode 100644 tests/Feature/SummarizeCollectionTest.php diff --git a/app/Events/ChatUiUpdateEvent.php b/app/Events/ChatUiUpdateEvent.php new file mode 100644 index 00000000..7b1de258 --- /dev/null +++ b/app/Events/ChatUiUpdateEvent.php @@ -0,0 +1,44 @@ + + */ + public function broadcastOn(): array + { + return [ + new PrivateChannel('collection.chat.'.$this->collection->id.'.'.$this->chat->id), + ]; + } + + /** + * The event's broadcast name. + */ + public function broadcastAs(): string + { + return 'update'; + } +} diff --git a/app/Http/Controllers/ChatController.php b/app/Http/Controllers/ChatController.php index 5ea2bd99..2c36e6ec 100644 --- a/app/Http/Controllers/ChatController.php +++ b/app/Http/Controllers/ChatController.php @@ -2,13 +2,16 @@ namespace App\Http\Controllers; +use App\Events\ChatUiUpdateEvent; use App\Events\ChatUpdatedEvent; use App\Http\Resources\ChatResource; use App\Http\Resources\CollectionResource; use App\Http\Resources\MessageResource; +use App\LlmDriver\LlmDriverFacade; +use Facades\App\LlmDriver\Orchestrate; +use App\LlmDriver\Requests\MessageInDto; use App\Models\Chat; use App\Models\Collection; -use Facades\App\Domains\Messages\SearchOrSummarizeChatRepo; class ChatController extends Controller { @@ -44,7 +47,15 @@ public function chat(Chat $chat) 'input' => 'required|string', ]); - $response = SearchOrSummarizeChatRepo::search($chat, $validated['input']); + $messagesArray = []; + + $messagesArray[] = MessageInDto::from([ + 'content' => $validated['input'], + 'role' => 'user', + ]); + + $response = Orchestrate::handle($messagesArray, $chat); + ChatUpdatedEvent::dispatch($chat->chatable, $chat); diff --git a/app/LlmDriver/BaseClient.php b/app/LlmDriver/BaseClient.php index 1c6f1e55..b5e25a26 100644 --- a/app/LlmDriver/BaseClient.php +++ b/app/LlmDriver/BaseClient.php @@ -59,6 +59,10 @@ public function functionPromptChat(array $messages, array $only = []): array } } + /** + * @TODO + * make this a dto + */ return $functions; } diff --git a/app/LlmDriver/Functions/ArgumentCaster.php b/app/LlmDriver/Functions/ArgumentCaster.php index d6dc732b..1dbaa2ed 100644 --- a/app/LlmDriver/Functions/ArgumentCaster.php +++ b/app/LlmDriver/Functions/ArgumentCaster.php @@ -1,22 +1,17 @@ '' + ] + ); } /** @@ -21,9 +32,9 @@ protected function getProperties(): array return [ new PropertyDto( name: 'prompt', - description: 'The prompt the user is using the search for.', + description: 'This is the prompt the user is using to search the database and may or may not assist the results.', type: 'string', - required: true, + required: false, ), ]; } diff --git a/app/LlmDriver/Functions/SummarizeCollection.php b/app/LlmDriver/Functions/SummarizeCollection.php new file mode 100644 index 00000000..b2b79e25 --- /dev/null +++ b/app/LlmDriver/Functions/SummarizeCollection.php @@ -0,0 +1,46 @@ + '' + ]); + } + + /** + * @return PropertyDto[] + */ + protected function getProperties(): array + { + return [ + new PropertyDto( + name: 'prompt', + description: 'The prompt the user is using the search for.', + type: 'string', + required: true, + ), + ]; + } +} diff --git a/app/LlmDriver/LlmDriverClient.php b/app/LlmDriver/LlmDriverClient.php index d432249f..75f8fadf 100644 --- a/app/LlmDriver/LlmDriverClient.php +++ b/app/LlmDriver/LlmDriverClient.php @@ -3,6 +3,7 @@ namespace App\LlmDriver; use App\LlmDriver\Functions\SearchAndSummarize; +use App\LlmDriver\Functions\SummarizeCollection; class LlmDriverClient { @@ -53,6 +54,7 @@ public function getFunctions(): array { return [ (new SearchAndSummarize())->getFunction(), + (new SummarizeCollection())->getFunction(), ]; } } diff --git a/app/LlmDriver/LlmServiceProvider.php b/app/LlmDriver/LlmServiceProvider.php index 6dbf9e22..c299b5bd 100644 --- a/app/LlmDriver/LlmServiceProvider.php +++ b/app/LlmDriver/LlmServiceProvider.php @@ -2,6 +2,8 @@ namespace App\LlmDriver; +use App\LlmDriver\Functions\SearchAndSummarize; +use App\LlmDriver\Functions\SummarizeCollection; use Illuminate\Support\ServiceProvider; class LlmServiceProvider extends ServiceProvider @@ -23,5 +25,13 @@ public function boot(): void return new LlmDriverClient(); }); + $this->app->bind('summarize_collection', function () { + return new SummarizeCollection(); + }); + + $this->app->bind('search_and_summarize', function() { + return new SearchAndSummarize(); + }); + } } diff --git a/app/LlmDriver/Orchestrate.php b/app/LlmDriver/Orchestrate.php new file mode 100644 index 00000000..8fbe22de --- /dev/null +++ b/app/LlmDriver/Orchestrate.php @@ -0,0 +1,118 @@ +chatable->getDriver()) + ->functionPromptChat($messagesArray); + + if(!empty($functions)) { + /** + * @TODO + * We will deal with multi functions shortly + * + * @TODO + * When should messages be made + * which class should make them + * In this case I will assume the user of this class + * save the Users input as a Message already + */ + foreach($functions as $function) { + $functionName = data_get($function, 'name', null); + + if(is_null($functionName)) { + throw new \Exception("Function name is required"); + } + + ChatUiUpdateEvent::dispatch( + $chat->chatable, + $chat, + sprintf("We are running the agent %s back shortly", + str($functionName)->headline()->toString() + ) + ); + + $functionClass = app()->make($functionName); + + $arguments = data_get($function, 'arguments'); + + $arguments = is_array($arguments) ? json_encode($arguments) : ""; + + $functionDto = FunctionCallDto::from([ + 'arguments' => $arguments, + 'function_name' => $functionName + ]); + + $response = $functionClass->handle($messagesArray, $chat, $functionDto); + + $chat->addInput( + message: $response->content, + role: RoleEnum::Assistant, + show_in_thread: false); + + $messagesArray[] = MessageInDto::from([ + 'role' => 'assistant', + 'content' => $response->content + ]); + + ChatUiUpdateEvent::dispatch( + $chat->chatable, + $chat, + "The Agent has completed the task going to the final step now"); + } + + + $results = LlmDriverFacade::driver($chat->chatable->getDriver()) + ->chat($messagesArray); + + + $chat->addInput( + message: $results->content, + role: RoleEnum::Assistant, + show_in_thread: true); + + /** + * Could just show this in the ui + */ + ChatUiUpdateEvent::dispatch( + $chat->chatable, + $chat, + $results->content); + + return $results->content; + } else { + /** + * @NOTE + * this assumes way too much + */ + $message = collect($messagesArray)->first( + function($message) { + return $message->role === 'user'; + } + )->content; + + return SearchOrSummarizeChatRepo::search($chat, $message); + } + } + +} \ No newline at end of file diff --git a/app/LlmDriver/Responses/FunctionResponse.php b/app/LlmDriver/Responses/FunctionResponse.php new file mode 100644 index 00000000..860fcfd6 --- /dev/null +++ b/app/LlmDriver/Responses/FunctionResponse.php @@ -0,0 +1,11 @@ + { router.reload({ preserveScroll: true, }) + }) + .listen('.update', (e) => { + console.log(e); + // Make a better ui for htis + toast.success(e.updateMessage); }); }); diff --git a/tests/Feature/FunctionCallDtoTest.php b/tests/Feature/FunctionCallDtoTest.php new file mode 100644 index 00000000..4780e4a3 --- /dev/null +++ b/tests/Feature/FunctionCallDtoTest.php @@ -0,0 +1,25 @@ + json_encode(['TLDR it for me']), + 'function_name' => 'summarize_collection', + ]); + + $this->assertIsArray($dto->arguments); + $this->assertNotEmpty($dto->arguments); + } +} diff --git a/tests/Feature/Http/Controllers/ChatControllerTest.php b/tests/Feature/Http/Controllers/ChatControllerTest.php index 1ffdd221..c3d720b7 100644 --- a/tests/Feature/Http/Controllers/ChatControllerTest.php +++ b/tests/Feature/Http/Controllers/ChatControllerTest.php @@ -2,10 +2,15 @@ namespace Tests\Feature\Http\Controllers; +use App\LlmDriver\LlmDriverFacade; +use Facades\App\LlmDriver\Orchestrate; +use Facades\App\LlmDriver\MockClient; +use Facades\App\Domains\Messages\SearchOrSummarizeChatRepo; use App\Models\Chat; use App\Models\Collection; use App\Models\Message; use App\Models\User; +use Illuminate\Support\Facades\Event; use Tests\TestCase; class ChatControllerTest extends TestCase @@ -25,6 +30,27 @@ public function test_can_create_chat_and_redirect(): void $this->assertDatabaseCount('chats', 1); } + public function test_a_function_based_chat() + { + $user = User::factory()->create(); + $collection = Collection::factory()->create(); + $chat = Chat::factory()->create([ + 'chatable_id' => $collection->id, + 'chatable_type' => Collection::class, + 'user_id' => $user->id, + ]); + + Orchestrate::shouldReceive('handle')->once()->andReturn("Yo"); + + $this->actingAs($user)->post(route('chats.messages.create', [ + 'chat' => $chat->id, + ]), + [ + 'system_prompt' => 'Foo', + 'input' => 'user input', + ])->assertOk(); + } + public function test_kick_off_chat_makes_system() { $user = User::factory()->create(); diff --git a/tests/Feature/OrchestrateTest.php b/tests/Feature/OrchestrateTest.php new file mode 100644 index 00000000..70642ced --- /dev/null +++ b/tests/Feature/OrchestrateTest.php @@ -0,0 +1,77 @@ +functionPromptChat')->once()->andReturn([ + [ + 'name' => 'summarize_collection', + 'arguments' => [ + "TLDR it for me" + ] + ], + ]); + + LlmDriverFacade::shouldReceive('driver->chat')->once()->andReturn( + CompletionResponse::from([ + 'content' => "Summarized" + ]) + ); + SearchOrSummarizeChatRepo::shouldReceive('search')->never(); + + $this->instance( + 'summarize_collection', + Mockery::mock(SummarizeCollection::class, function ($mock) { + $mock->shouldReceive('handle') + ->once() + ->andReturn( + FunctionResponse::from(['content' => 'This is the summary of the collection']) + ); + }) + ); + + $user = User::factory()->create(); + $collection = Collection::factory()->create(); + $chat = Chat::factory()->create([ + 'chatable_id' => $collection->id, + 'chatable_type' => Collection::class, + 'user_id' => $user->id, + ]); + + $messageDto = MessageInDto::from([ + 'content' => 'TLDR it for me', + 'role' => 'user', + ]); + + $results = (new Orchestrate())->handle([$messageDto], $chat); + + Event::assertDispatched(ChatUiUpdateEvent::class); + + $this->assertEquals($results, 'Summarized'); + + } +} diff --git a/tests/Feature/SummarizeCollectionTest.php b/tests/Feature/SummarizeCollectionTest.php new file mode 100644 index 00000000..e69856b7 --- /dev/null +++ b/tests/Feature/SummarizeCollectionTest.php @@ -0,0 +1,26 @@ +getFunction(); + + $parameters = $function->parameters; + + $this->assertInstanceOf(ParametersDto::class, $parameters); + $this->assertIsArray($parameters->properties); + $this->assertInstanceOf(PropertyDto::class, $parameters->properties[0]); + } +} From 6b148b8b5ede498f71013945a16b5240692a559d Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Fri, 5 Apr 2024 10:04:11 -0400 Subject: [PATCH 05/10] ok stab 1 at this but I think a while loop could be better --- app/Http/Controllers/ChatController.php | 11 +- app/LlmDriver/BaseClient.php | 8 +- app/LlmDriver/Functions/ArgumentCaster.php | 2 - app/LlmDriver/Functions/FunctionCallDto.php | 1 - app/LlmDriver/Functions/FunctionContract.php | 9 +- .../Functions/SearchAndSummarize.php | 5 +- .../Functions/SummarizeCollection.php | 42 ++++-- app/LlmDriver/LlmServiceProvider.php | 2 +- app/LlmDriver/OpenAiClient.php | 49 ++++++- app/LlmDriver/Orchestrate.php | 91 ++++++------ app/LlmDriver/Responses/FunctionResponse.php | 8 +- app/Models/Chat.php | 16 ++- composer.json | 5 +- composer.lock | 133 ++++++++++++------ resources/js/Pages/Chat/ChatBaloon.vue | 8 +- resources/js/Pages/Chat/ChatInputThreaded.vue | 17 +++ resources/js/Pages/Chat/ChatUi.vue | 4 +- resources/js/Pages/Collection/Chat.vue | 15 +- resources/js/app.js | 6 +- tests/Feature/FunctionCallDtoTest.php | 2 - .../Http/Controllers/ChatControllerTest.php | 12 +- tests/Feature/OpenAiClientTest.php | 49 +++++++ tests/Feature/OrchestrateTest.php | 32 ++--- tests/Feature/SummarizeCollectionTest.php | 54 +++++++ ...e_with_functions_summarize_collection.json | 33 +++++ 25 files changed, 455 insertions(+), 159 deletions(-) create mode 100644 tests/fixtures/openai_response_with_functions_summarize_collection.json diff --git a/app/Http/Controllers/ChatController.php b/app/Http/Controllers/ChatController.php index 2c36e6ec..b0e89681 100644 --- a/app/Http/Controllers/ChatController.php +++ b/app/Http/Controllers/ChatController.php @@ -2,16 +2,15 @@ namespace App\Http\Controllers; -use App\Events\ChatUiUpdateEvent; +use App\Domains\Messages\RoleEnum; use App\Events\ChatUpdatedEvent; use App\Http\Resources\ChatResource; use App\Http\Resources\CollectionResource; use App\Http\Resources\MessageResource; -use App\LlmDriver\LlmDriverFacade; -use Facades\App\LlmDriver\Orchestrate; use App\LlmDriver\Requests\MessageInDto; use App\Models\Chat; use App\Models\Collection; +use Facades\App\LlmDriver\Orchestrate; class ChatController extends Controller { @@ -47,6 +46,11 @@ public function chat(Chat $chat) 'input' => 'required|string', ]); + $chat->addInput( + message: $validated['input'], + role: RoleEnum::User, + show_in_thread: true); + $messagesArray = []; $messagesArray[] = MessageInDto::from([ @@ -56,7 +60,6 @@ public function chat(Chat $chat) $response = Orchestrate::handle($messagesArray, $chat); - ChatUpdatedEvent::dispatch($chat->chatable, $chat); return response()->json(['message' => $response]); diff --git a/app/LlmDriver/BaseClient.php b/app/LlmDriver/BaseClient.php index b5e25a26..3cd53617 100644 --- a/app/LlmDriver/BaseClient.php +++ b/app/LlmDriver/BaseClient.php @@ -41,7 +41,7 @@ public function functionPromptChat(array $messages, array $only = []): array Log::info('LlmDriver::MockClient::functionPromptChat', $messages); - $data = get_fixture('openai_response_with_functions.json'); + $data = get_fixture('openai_response_with_functions_summarize_collection.json'); $functions = []; @@ -60,7 +60,7 @@ public function functionPromptChat(array $messages, array $only = []): array } /** - * @TODO + * @TODO * make this a dto */ return $functions; @@ -77,9 +77,7 @@ public function chat(array $messages): CompletionResponse Log::info('LlmDriver::MockClient::completion'); - $data = <<<'EOD' - Voluptate irure cillum dolor anim officia reprehenderit dolor. Eiusmod veniam nostrud consectetur incididunt proident id. Anim adipisicing pariatur amet duis Lorem sunt veniam veniam est. Deserunt ea aliquip cillum pariatur consectetur. Dolor in reprehenderit adipisicing consectetur cupidatat ad cupidatat reprehenderit. Nostrud mollit voluptate aliqua anim pariatur excepteur eiusmod velit quis exercitation tempor quis excepteur. -EOD; + $data = fake()->sentences(3, true); return new CompletionResponse($data); } diff --git a/app/LlmDriver/Functions/ArgumentCaster.php b/app/LlmDriver/Functions/ArgumentCaster.php index 1dbaa2ed..a3b2a97d 100644 --- a/app/LlmDriver/Functions/ArgumentCaster.php +++ b/app/LlmDriver/Functions/ArgumentCaster.php @@ -8,10 +8,8 @@ class ArgumentCaster implements Cast { - public function cast(DataProperty $property, mixed $value, array $properties, CreationContext $context): array { return json_decode($value, true); } - } diff --git a/app/LlmDriver/Functions/FunctionCallDto.php b/app/LlmDriver/Functions/FunctionCallDto.php index dbbfaf6d..48c7bca2 100644 --- a/app/LlmDriver/Functions/FunctionCallDto.php +++ b/app/LlmDriver/Functions/FunctionCallDto.php @@ -2,7 +2,6 @@ namespace App\LlmDriver\Functions; -use App\LlmDriver\Functions\ArgumentCaster; use Spatie\LaravelData\Attributes\WithCast; class FunctionCallDto extends \Spatie\LaravelData\Data diff --git a/app/LlmDriver/Functions/FunctionContract.php b/app/LlmDriver/Functions/FunctionContract.php index b47ed21f..19cc91a9 100644 --- a/app/LlmDriver/Functions/FunctionContract.php +++ b/app/LlmDriver/Functions/FunctionContract.php @@ -2,10 +2,9 @@ namespace App\LlmDriver\Functions; - -use App\Models\Chat; use App\LlmDriver\Requests\MessageInDto; use App\LlmDriver\Responses\FunctionResponse; +use App\Models\Chat; abstract class FunctionContract { @@ -16,11 +15,7 @@ abstract class FunctionContract protected string $type = 'object'; /** - * - * @param MessageInDto[] $messageArray - * @param App\LlmDriver\Functions\Chat $chat - * @param FunctionCallDto $functionCallDto - * @return array + * @param MessageInDto[] $messageArray */ abstract public function handle( array $messageArray, diff --git a/app/LlmDriver/Functions/SearchAndSummarize.php b/app/LlmDriver/Functions/SearchAndSummarize.php index 980c58d6..6c13f609 100644 --- a/app/LlmDriver/Functions/SearchAndSummarize.php +++ b/app/LlmDriver/Functions/SearchAndSummarize.php @@ -2,9 +2,8 @@ namespace App\LlmDriver\Functions; -use App\Models\Chat; -use App\LlmDriver\Requests\MessageInDto; use App\LlmDriver\Responses\FunctionResponse; +use App\Models\Chat; class SearchAndSummarize extends FunctionContract { @@ -19,7 +18,7 @@ public function handle( { return FunctionResponse::from( [ - 'content' => '' + 'content' => '', ] ); } diff --git a/app/LlmDriver/Functions/SummarizeCollection.php b/app/LlmDriver/Functions/SummarizeCollection.php index b2b79e25..74bbc94b 100644 --- a/app/LlmDriver/Functions/SummarizeCollection.php +++ b/app/LlmDriver/Functions/SummarizeCollection.php @@ -2,9 +2,12 @@ namespace App\LlmDriver\Functions; -use App\Models\Chat; +use App\Domains\Messages\RoleEnum; +use App\LlmDriver\LlmDriverFacade; use App\LlmDriver\Requests\MessageInDto; use App\LlmDriver\Responses\FunctionResponse; +use App\Models\Chat; +use Illuminate\Support\Facades\Log; class SummarizeCollection extends FunctionContract { @@ -12,20 +15,41 @@ class SummarizeCollection extends FunctionContract protected string $description = 'This is used when the prompt wants to summarize the entire collection of documents'; - /** - * - * @param MessageInDto[] $messageArray - * @param App\LlmDriver\Functions\Chat $chat - * @param FunctionCallDto $functionCallDto - * @return array - */ public function handle( array $messageArray, Chat $chat, FunctionCallDto $functionCallDto): FunctionResponse { + Log::info('[LaraChain] SummarizeCollection function called'); + + $summary = collect([]); + + foreach ($chat->chatable->documents as $document) { + foreach ($document->document_chunks as $chunk) { + $summary->add($chunk->summary); + } + } + + $summary = $summary->implode('\n'); + + $prompt = 'Can you summarize all of this content for me from a collection of documents I uploaded what follows is the content: '.$summary; + + $messagesArray = []; + + $messagesArray[] = MessageInDto::from([ + 'content' => $prompt, + 'role' => 'user', + ]); + + $results = LlmDriverFacade::driver($chat->getDriver())->chat($messagesArray); + + $chat->addInput( + message: $results->content, + role: RoleEnum::Assistant, + show_in_thread: true); + return FunctionResponse::from([ - 'content' => '' + 'content' => $results->content, ]); } diff --git a/app/LlmDriver/LlmServiceProvider.php b/app/LlmDriver/LlmServiceProvider.php index c299b5bd..d511c904 100644 --- a/app/LlmDriver/LlmServiceProvider.php +++ b/app/LlmDriver/LlmServiceProvider.php @@ -29,7 +29,7 @@ public function boot(): void return new SummarizeCollection(); }); - $this->app->bind('search_and_summarize', function() { + $this->app->bind('search_and_summarize', function () { return new SearchAndSummarize(); }); diff --git a/app/LlmDriver/OpenAiClient.php b/app/LlmDriver/OpenAiClient.php index bc2d324c..5841dea9 100644 --- a/app/LlmDriver/OpenAiClient.php +++ b/app/LlmDriver/OpenAiClient.php @@ -5,6 +5,7 @@ use App\LlmDriver\Requests\MessageInDto; use App\LlmDriver\Responses\CompletionResponse; use App\LlmDriver\Responses\EmbeddingsResponseDto; +use Illuminate\Support\Facades\Log; use OpenAI\Laravel\Facades\OpenAI; class OpenAiClient extends BaseClient @@ -16,15 +17,12 @@ class OpenAiClient extends BaseClient */ public function chat(array $messages): CompletionResponse { - $functions = $this->getFunctions(); $response = OpenAI::chat()->create([ 'model' => $this->getConfig('openai')['chat_model'], 'messages' => collect($messages)->map(function ($message) { return $message->toArray(); })->toArray(), - 'tool_choice' => 'auto', - 'tools' => $functions, ]); $results = null; @@ -74,6 +72,51 @@ public function completion(string $prompt, int $temperature = 0): CompletionResp return new CompletionResponse($results); } + /** + * This is to get functions out of the llm + * if none are returned your system + * can error out or try another way. + * + * @param MessageInDto[] $messages + */ + public function functionPromptChat(array $messages, array $only = []): array + { + + Log::info('LlmDriver::OpenAiClient::functionPromptChat', $messages); + + $functions = $this->getFunctions(); + + $response = OpenAI::chat()->create([ + 'model' => $this->getConfig('openai')['chat_model'], + 'messages' => collect($messages)->map(function ($message) { + return $message->toArray(); + })->toArray(), + 'tool_choice' => 'auto', + 'tools' => $functions, + ]); + + $functions = []; + foreach ($response->choices as $result) { + foreach (data_get($result, 'message.toolCalls', []) as $tool) { + if (data_get($tool, 'type') === 'function') { + $name = data_get($tool, 'function.name', null); + if (! in_array($name, $only)) { + $functions[] = [ + 'name' => $name, + 'arguments' => json_decode(data_get($tool, 'function.arguments', []), true), + ]; + } + } + } + } + + /** + * @TODO + * make this a dto + */ + return $functions; + } + /** * @NOTE * Since this abstraction layer is based on OpenAi diff --git a/app/LlmDriver/Orchestrate.php b/app/LlmDriver/Orchestrate.php index 8fbe22de..f4df90f2 100644 --- a/app/LlmDriver/Orchestrate.php +++ b/app/LlmDriver/Orchestrate.php @@ -1,112 +1,124 @@ -chatable->getDriver()) ->functionPromptChat($messagesArray); - if(!empty($functions)) { + if (! empty($functions)) { /** * @TODO * We will deal with multi functions shortly - * - * @TODO - * When should messages be made + * @TODO + * When should messages be made * which class should make them * In this case I will assume the user of this class * save the Users input as a Message already */ - foreach($functions as $function) { + foreach ($functions as $function) { $functionName = data_get($function, 'name', null); - - if(is_null($functionName)) { - throw new \Exception("Function name is required"); + + if (is_null($functionName)) { + throw new \Exception('Function name is required'); } ChatUiUpdateEvent::dispatch( - $chat->chatable, - $chat, - sprintf("We are running the agent %s back shortly", + $chat->chatable, + $chat, + sprintf('We are running the agent %s back shortly', str($functionName)->headline()->toString() - ) - ); + ) + ); $functionClass = app()->make($functionName); $arguments = data_get($function, 'arguments'); - $arguments = is_array($arguments) ? json_encode($arguments) : ""; + $arguments = is_array($arguments) ? json_encode($arguments) : ''; $functionDto = FunctionCallDto::from([ 'arguments' => $arguments, - 'function_name' => $functionName + 'function_name' => $functionName, ]); + /** @var FunctionResponse $response */ $response = $functionClass->handle($messagesArray, $chat, $functionDto); $chat->addInput( - message: $response->content, + message: $response->content, role: RoleEnum::Assistant, show_in_thread: false); - $messagesArray[] = MessageInDto::from([ + $messagesArray = Arr::wrap(MessageInDto::from([ 'role' => 'assistant', - 'content' => $response->content - ]); + 'content' => $response->content, + ])); ChatUiUpdateEvent::dispatch( $chat->chatable, $chat, - "The Agent has completed the task going to the final step now"); - } + 'The Agent has completed the task going to the final step now'); + $this->response = $response->content; + $this->requiresFollowup = $response->requires_follow_up_prompt; + } + /** + * @NOTE the function might return the results of a table + * or csv file or image info etc. + * This prompt should consider the initial prompt and the output of the function(s) + */ + if ($this->requiresFollowup) { $results = LlmDriverFacade::driver($chat->chatable->getDriver()) ->chat($messagesArray); - $chat->addInput( - message: $results->content, + message: $results->content, role: RoleEnum::Assistant, show_in_thread: true); - + /** * Could just show this in the ui - */ + */ ChatUiUpdateEvent::dispatch( $chat->chatable, $chat, $results->content); - return $results->content; + $this->response = $results->content; + } + + return $this->response; } else { /** - * @NOTE + * @NOTE * this assumes way too much */ $message = collect($messagesArray)->first( - function($message) { + function ($message) { return $message->role === 'user'; } )->content; @@ -114,5 +126,4 @@ function($message) { return SearchOrSummarizeChatRepo::search($chat, $message); } } - -} \ No newline at end of file +} diff --git a/app/LlmDriver/Responses/FunctionResponse.php b/app/LlmDriver/Responses/FunctionResponse.php index 860fcfd6..a140223e 100644 --- a/app/LlmDriver/Responses/FunctionResponse.php +++ b/app/LlmDriver/Responses/FunctionResponse.php @@ -2,10 +2,16 @@ namespace App\LlmDriver\Responses; +/** + * @NOTE + * Requires follow up with be for example results of a panda query on a csv file + * maybe more info is needed from an llm or agent + */ class FunctionResponse extends \Spatie\LaravelData\Data { public function __construct( - public string $content + public string $content, + public bool $requires_follow_up_prompt = false ) { } } diff --git a/app/Models/Chat.php b/app/Models/Chat.php index 9f430d9a..f2a2ab2d 100644 --- a/app/Models/Chat.php +++ b/app/Models/Chat.php @@ -3,6 +3,7 @@ namespace App\Models; use App\Domains\Messages\RoleEnum; +use App\LlmDriver\HasDrivers; use App\LlmDriver\Requests\MessageInDto; use Illuminate\Database\Eloquent\Factories\HasFactory; use Illuminate\Database\Eloquent\Model; @@ -14,11 +15,22 @@ /** * @property mixed $chatable; */ -class Chat extends Model +class Chat extends Model implements HasDrivers { use HasFactory; - protected $fillable = []; + protected $guarded = []; + + public function getDriver(): string + { + + return $this->chatable->getDriver(); + } + + public function getEmbeddingDriver(): string + { + return $this->chatable->getEmbeddingDriver(); + } protected function createSystemMessageIfNeeded(string $systemPrompt): void { diff --git a/composer.json b/composer.json index 193d6583..8883b071 100644 --- a/composer.json +++ b/composer.json @@ -8,6 +8,7 @@ "php": "^8.2", "ankane/pgvector": "^0.1.3", "archtechx/enums": "^0.3.2", + "fakerphp/faker": "^1.23", "inertiajs/inertia-laravel": "^1.0", "laravel/framework": "^11.0", "laravel/horizon": "^5.23", @@ -24,10 +25,10 @@ "spatie/laravel-markdown": "^2.3", "tightenco/ziggy": "^2.0", "voku/stop-words": "^2.0", - "wamania/php-stemmer": "^3.0" + "wamania/php-stemmer": "^3.0", + "yethee/tiktoken": "^0.3.0" }, "require-dev": { - "fakerphp/faker": "^1.23", "larastan/larastan": "^2.9", "laravel/pint": "^1.13", "laravel/sail": "^1.26", diff --git a/composer.lock b/composer.lock index 99fb3b20..a6c1f4a8 100644 --- a/composer.lock +++ b/composer.lock @@ -4,7 +4,7 @@ "Read more about it at https://getcomposer.org/doc/01-basic-usage.md#installing-dependencies", "This file is @generated automatically" ], - "content-hash": "1313d049706577645a0e754579b74eaa", + "content-hash": "5a7b8196ee2ce364300f100ffe0212a0", "packages": [ { "name": "amphp/amp", @@ -1949,6 +1949,69 @@ }, "time": "2023-08-08T05:53:35+00:00" }, + { + "name": "fakerphp/faker", + "version": "v1.23.1", + "source": { + "type": "git", + "url": "https://github.com/FakerPHP/Faker.git", + "reference": "bfb4fe148adbf78eff521199619b93a52ae3554b" + }, + "dist": { + "type": "zip", + "url": "https://api.github.com/repos/FakerPHP/Faker/zipball/bfb4fe148adbf78eff521199619b93a52ae3554b", + "reference": "bfb4fe148adbf78eff521199619b93a52ae3554b", + "shasum": "" + }, + "require": { + "php": "^7.4 || ^8.0", + "psr/container": "^1.0 || ^2.0", + "symfony/deprecation-contracts": "^2.2 || ^3.0" + }, + "conflict": { + "fzaninotto/faker": "*" + }, + "require-dev": { + "bamarni/composer-bin-plugin": "^1.4.1", + "doctrine/persistence": "^1.3 || ^2.0", + "ext-intl": "*", + "phpunit/phpunit": "^9.5.26", + "symfony/phpunit-bridge": "^5.4.16" + }, + "suggest": { + "doctrine/orm": "Required to use Faker\\ORM\\Doctrine", + "ext-curl": "Required by Faker\\Provider\\Image to download images.", + "ext-dom": "Required by Faker\\Provider\\HtmlLorem for generating random HTML.", + "ext-iconv": "Required by Faker\\Provider\\ru_RU\\Text::realText() for generating real Russian text.", + "ext-mbstring": "Required for multibyte Unicode string functionality." + }, + "type": "library", + "autoload": { + "psr-4": { + "Faker\\": "src/Faker/" + } + }, + "notification-url": "https://packagist.org/downloads/", + "license": [ + "MIT" + ], + "authors": [ + { + "name": "François Zaninotto" + } + ], + "description": "Faker is a PHP library that generates fake data for you.", + "keywords": [ + "data", + "faker", + "fixtures" + ], + "support": { + "issues": "https://github.com/FakerPHP/Faker/issues", + "source": "https://github.com/FakerPHP/Faker/tree/v1.23.1" + }, + "time": "2024-01-02T13:46:09+00:00" + }, { "name": "fruitcake/php-cors", "version": "v1.3.0", @@ -10562,72 +10625,58 @@ "source": "https://github.com/webmozarts/assert/tree/1.11.0" }, "time": "2022-06-03T18:03:27+00:00" - } - ], - "packages-dev": [ + }, { - "name": "fakerphp/faker", - "version": "v1.23.1", + "name": "yethee/tiktoken", + "version": "0.3.0", "source": { "type": "git", - "url": "https://github.com/FakerPHP/Faker.git", - "reference": "bfb4fe148adbf78eff521199619b93a52ae3554b" + "url": "https://github.com/yethee/tiktoken-php.git", + "reference": "c84f066dcb0cff60685537c71bf795967775446a" }, "dist": { "type": "zip", - "url": "https://api.github.com/repos/FakerPHP/Faker/zipball/bfb4fe148adbf78eff521199619b93a52ae3554b", - "reference": "bfb4fe148adbf78eff521199619b93a52ae3554b", + "url": "https://api.github.com/repos/yethee/tiktoken-php/zipball/c84f066dcb0cff60685537c71bf795967775446a", + "reference": "c84f066dcb0cff60685537c71bf795967775446a", "shasum": "" }, "require": { - "php": "^7.4 || ^8.0", - "psr/container": "^1.0 || ^2.0", - "symfony/deprecation-contracts": "^2.2 || ^3.0" - }, - "conflict": { - "fzaninotto/faker": "*" + "php": "^8.1", + "symfony/service-contracts": "^2.5 || ^3.0" }, "require-dev": { - "bamarni/composer-bin-plugin": "^1.4.1", - "doctrine/persistence": "^1.3 || ^2.0", - "ext-intl": "*", - "phpunit/phpunit": "^9.5.26", - "symfony/phpunit-bridge": "^5.4.16" - }, - "suggest": { - "doctrine/orm": "Required to use Faker\\ORM\\Doctrine", - "ext-curl": "Required by Faker\\Provider\\Image to download images.", - "ext-dom": "Required by Faker\\Provider\\HtmlLorem for generating random HTML.", - "ext-iconv": "Required by Faker\\Provider\\ru_RU\\Text::realText() for generating real Russian text.", - "ext-mbstring": "Required for multibyte Unicode string functionality." + "doctrine/coding-standard": "^12.0", + "phpunit/phpunit": "^10.3", + "psalm/plugin-phpunit": "^0.18.3", + "vimeo/psalm": "5.19.0" }, "type": "library", "autoload": { "psr-4": { - "Faker\\": "src/Faker/" + "Yethee\\Tiktoken\\": "src" } }, "notification-url": "https://packagist.org/downloads/", "license": [ "MIT" ], - "authors": [ - { - "name": "François Zaninotto" - } - ], - "description": "Faker is a PHP library that generates fake data for you.", + "description": "PHP version of tiktoken", "keywords": [ - "data", - "faker", - "fixtures" + "bpe", + "decode", + "encode", + "openai", + "tiktoken", + "tokenizer" ], "support": { - "issues": "https://github.com/FakerPHP/Faker/issues", - "source": "https://github.com/FakerPHP/Faker/tree/v1.23.1" + "issues": "https://github.com/yethee/tiktoken-php/issues", + "source": "https://github.com/yethee/tiktoken-php/tree/0.3.0" }, - "time": "2024-01-02T13:46:09+00:00" - }, + "time": "2024-01-10T10:34:57+00:00" + } + ], + "packages-dev": [ { "name": "filp/whoops", "version": "2.15.4", diff --git a/resources/js/Pages/Chat/ChatBaloon.vue b/resources/js/Pages/Chat/ChatBaloon.vue index ca0a3c65..01986fe0 100644 --- a/resources/js/Pages/Chat/ChatBaloon.vue +++ b/resources/js/Pages/Chat/ChatBaloon.vue @@ -38,18 +38,16 @@ const props = defineProps({ -
+
- -
diff --git a/resources/js/Pages/Chat/ChatInputThreaded.vue b/resources/js/Pages/Chat/ChatInputThreaded.vue index 5ae21572..99ac42f5 100644 --- a/resources/js/Pages/Chat/ChatInputThreaded.vue +++ b/resources/js/Pages/Chat/ChatInputThreaded.vue @@ -56,7 +56,24 @@ const setQuestion = (question) => {