diff --git a/app/Domains/Messages/SearchOrSummarizeChatRepo.php b/app/Domains/Messages/SearchOrSummarizeChatRepo.php index d91a6f97..bd5df8e8 100644 --- a/app/Domains/Messages/SearchOrSummarizeChatRepo.php +++ b/app/Domains/Messages/SearchOrSummarizeChatRepo.php @@ -23,7 +23,7 @@ public function search(Chat $chat, string $input): string /** @var EmbeddingsResponseDto $embedding */ $embedding = LlmDriverFacade::driver( - $chat->chatable->getDriver() + /* The query is embedded with the chatable's embedding driver rather than its chat driver. */ $chat->chatable->getEmbeddingDriver() )->embedData($input); $results = DocumentChunk::query() diff --git a/app/Http/Controllers/CollectionController.php b/app/Http/Controllers/CollectionController.php index d79565d4..e92c7bb4 100644 --- a/app/Http/Controllers/CollectionController.php +++ b/app/Http/Controllers/CollectionController.php @@ -31,6 +31,7 @@ public function store() 'name' => 'required', 'description' => 'required', 'driver' => 'required', + 'embedding_driver' => 'required', ]); $validated['team_id'] = auth()->user()->current_team_id; diff --git a/app/Jobs/VectorlizeDataJob.php b/app/Jobs/VectorlizeDataJob.php index 868762a5..97811ab6 100644 --- a/app/Jobs/VectorlizeDataJob.php +++ b/app/Jobs/VectorlizeDataJob.php @@ -43,7 +43,7 @@ public function handle(): void $content = $this->documentChunk->content; /** @var EmbeddingsResponseDto $results */ - $results = LlmDriverFacade::driver($this->documentChunk->getDriver()) + $results = LlmDriverFacade::driver($this->documentChunk->getEmbeddingDriver()) ->embedData($content); $this->documentChunk->update([ diff --git a/app/LlmDriver/ClaudeClient.php b/app/LlmDriver/ClaudeClient.php index 183525bb..5b42416d 100644 --- a/app/LlmDriver/ClaudeClient.php +++ b/app/LlmDriver/ClaudeClient.php @@ -23,8 +23,6 @@ public function embedData(string $data): EmbeddingsResponseDto Log::info('LlmDriver::ClaudeClient::embedData'); - - return EmbeddingsResponseDto::from([ 'embedding' => data_get($data, 'data.0.embedding'), 'token_count' => 1000, @@ -36,17 +34,40 @@ public function 
embedData(string $data): EmbeddingsResponseDto */ public function chat(array $messages): CompletionResponse { - if (! app()->environment('testing')) { - sleep(2); + /* Posts the messages to /messages and returns the response text; throws \Exception when the response is not OK. */ $model = $this->getConfig('claude')['models']['completion_model']; + $maxTokens = $this->getConfig('claude')['max_tokens']; + + Log::info('LlmDriver::Claude::completion'); + + $messages = collect($messages)->map(function($item) { + if($item->role === 'system') { + $item->role = 'assistant'; + } + + return $item->toArray(); + })->reverse()->values()->all(); + + $results = $this->getClient()->post('/messages', [ + 'model' => $model, + "max_tokens" => $maxTokens, + 'messages' => $messages, + ]); + + if(!$results->ok()) { + /* Keyed lookup with a default: the error body may be non-JSON or lack error.type. */ $error = $results->json('error.type', 'unknown'); + Log::error('Claude API Error ' . $error); + throw new \Exception('Claude API Error ' . $error); } - Log::info('LlmDriver::MockClient::completion'); + $data = ''; - $data = <<<'EOD' - Voluptate irure cillum dolor anim officia reprehenderit dolor. Eiusmod veniam nostrud consectetur incididunt proident id. Anim adipisicing pariatur amet duis Lorem sunt veniam veniam est. Deserunt ea aliquip cillum pariatur consectetur. Dolor in reprehenderit adipisicing consectetur cupidatat ad cupidatat reprehenderit. Nostrud mollit voluptate aliqua anim pariatur excepteur eiusmod velit quis exercitation tempor quis excepteur. 
-EOD; + /* Append every block's text instead of keeping only the last one; blocks without a text key are skipped. */ foreach($results->json('content', []) as $content) { + $data .= $content['text'] ?? ''; + } - return new CompletionResponse($data); + return CompletionResponse::from([ + 'content' => $data, + ]); } public function completion(string $prompt): CompletionResponse diff --git a/app/LlmDriver/DriversEnum.php b/app/LlmDriver/DriversEnum.php new file mode 100644 index 00000000..0ce69dec --- /dev/null +++ b/app/LlmDriver/DriversEnum.php @@ -0,0 +1,13 @@ + 'boolean', + 'driver' => DriversEnum::class, + 'embedding_driver' => DriversEnum::class ]; + public function team(): BelongsTo { return $this->belongsTo(Team::class); @@ -36,7 +43,12 @@ public function team(): BelongsTo public function getDriver(): string { - return $this->driver; + return $this->driver->value; + } + + public function getEmbeddingDriver(): string + { + return $this->embedding_driver->value; } public function documents(): HasMany diff --git a/app/Models/Document.php b/app/Models/Document.php index c436533b..3fac048c 100644 --- a/app/Models/Document.php +++ b/app/Models/Document.php @@ -4,6 +4,7 @@ use App\Domains\Documents\StatusEnum; use App\Domains\Documents\TypesEnum; +use App\LlmDriver\HasDrivers; use Illuminate\Database\Eloquent\Factories\HasFactory; use Illuminate\Database\Eloquent\Model; use Illuminate\Database\Eloquent\Relations\BelongsTo; @@ -17,7 +18,7 @@ * @property string|null $summary * @property string|null $file_path */ -class Document extends Model +class Document extends Model implements HasDrivers { use HasFactory; @@ -29,6 +30,7 @@ class Document extends Model 'summary_status' => StatusEnum::class, ]; + public function collection(): BelongsTo { return $this->belongsTo(Collection::class); @@ -58,6 +60,11 @@ public function mkdirPathToFile(): ?string public function getDriver(): string { - return $this->collection->driver; + return $this->collection->driver->value; + } + + + public function getEmbeddingDriver(): string { + return $this->collection->embedding_driver->value; } } diff --git 
a/app/Models/DocumentChunk.php b/app/Models/DocumentChunk.php index c9b9d2d7..5c819919 100644 --- a/app/Models/DocumentChunk.php +++ b/app/Models/DocumentChunk.php @@ -3,6 +3,7 @@ namespace App\Models; use App\Domains\Documents\StatusEnum; +use App\LlmDriver\HasDrivers; use Illuminate\Database\Eloquent\Factories\HasFactory; use Illuminate\Database\Eloquent\Model; use Pgvector\Laravel\Vector; @@ -10,7 +11,7 @@ /** * @property Document $document */ -class DocumentChunk extends Model +class DocumentChunk extends Model implements HasDrivers { use HasFactory; @@ -41,8 +42,13 @@ protected static function booted() }); } + /** Returns the value of the embedding_driver set on this chunk's document's collection. */ public function getEmbeddingDriver(): string + { + return $this->document->collection->embedding_driver->value; + } + /** Returns the value of the driver set on this chunk's document's collection. */ public function getDriver(): string { - return $this->document->collection->driver; + return $this->document->collection->driver->value; } } diff --git a/config/llmdriver.php b/config/llmdriver.php index ad74dc29..e56c26e7 100644 --- a/config/llmdriver.php +++ b/config/llmdriver.php @@ -19,8 +19,9 @@ ], 'claude' => [ 'api_key' => env('CLAUDE_API_KEY'), - 'max_tokens' => env('CLAUDE_MAX_TOKENS', 1024), + 'max_tokens' => env('CLAUDE_MAX_TOKENS', 4000), 'models' => [ + //@see https://www.anthropic.com/news/claude-3-family 'completion_model' => env('CLAUDE_COMPLETION_MODEL', 'claude-3-opus-20240229'), ] ], diff --git a/database/factories/CollectionFactory.php b/database/factories/CollectionFactory.php index ec0de092..498246a2 100644 --- a/database/factories/CollectionFactory.php +++ b/database/factories/CollectionFactory.php @@ -2,6 +2,7 @@ namespace Database\Factories; +use App\LlmDriver\DriversEnum; use App\Models\Team; use Illuminate\Database\Eloquent\Factories\Factory; @@ -22,6 +23,8 @@ public function definition(): array 'description' => $this->faker->paragraph, 'active' => $this->faker->boolean, 'team_id' => Team::factory(), + /* Factory-built collections use the Mock driver for both chat and embeddings. */ 'driver' => DriversEnum::Mock, + 'embedding_driver' => DriversEnum::Mock, ]; } } diff --git 
a/database/migrations/2024_03_29_200402_add_embedding_driver_to_collections_model.php b/database/migrations/2024_03_29_200402_add_embedding_driver_to_collections_model.php new file mode 100644 index 00000000..6de31dc0 --- /dev/null +++ b/database/migrations/2024_03_29_200402_add_embedding_driver_to_collections_model.php @@ -0,0 +1,29 @@ +string('embedding_driver')->default(DriversEnum::Mock); + }); + } + + /** + * Reverse the migrations: drop the embedding_driver column from the collections table. + */ + public function down(): void + { + Schema::table('collections', function (Blueprint $table) { + $table->dropColumn('embedding_driver'); + }); + } +}; diff --git a/resources/js/Pages/Collection/Components/EmbeddingType.vue b/resources/js/Pages/Collection/Components/EmbeddingType.vue new file mode 100644 index 00000000..f8239bb6 --- /dev/null +++ b/resources/js/Pages/Collection/Components/EmbeddingType.vue @@ -0,0 +1,66 @@ + + + \ No newline at end of file diff --git a/resources/js/Pages/Collection/Components/LlmType.vue b/resources/js/Pages/Collection/Components/LlmType.vue index da86f36e..2ad45dfa 100644 --- a/resources/js/Pages/Collection/Components/LlmType.vue +++ b/resources/js/Pages/Collection/Components/LlmType.vue @@ -1,6 +1,5 @@