From e53dcaa00e381fcee39423a47fcaa53af95a7fc7 Mon Sep 17 00:00:00 2001 From: Alfred Nutile Date: Tue, 26 Mar 2024 21:46:48 -0400 Subject: [PATCH] driver working --- app/Jobs/SummarizeDataJob.php | 4 +- app/Jobs/SummarizeDocumentJob.php | 4 +- app/Jobs/VectorlizeDataJob.php | 5 ++- app/LlmDriver/LlmDriverClient.php | 37 ++++++++++++------- app/LlmDriver/MockClient.php | 4 -- app/LlmDriver/OpenAiClient.php | 17 ++++----- app/Models/Document.php | 5 +++ app/Models/DocumentChunk.php | 9 +++++ app/Providers/AppServiceProvider.php | 2 +- ...03_27_005749_add_driver_to_collections.php | 28 ++++++++++++++ ..._change_embeddings_on_document_chunks.php} | 0 tests/Feature/Jobs/VectorlizeDataJobTest.php | 2 +- tests/Feature/LlmDriverClientTest.php | 28 ++++++++++++++ tests/Feature/LlmDriverFacadeTest.php | 2 +- tests/Feature/OpenAiClientTest.php | 13 +++---- tests/Feature/SummarizeDataJobTest.php | 2 +- 16 files changed, 119 insertions(+), 43 deletions(-) create mode 100644 database/migrations/2024_03_27_005749_add_driver_to_collections.php rename database/migrations/{2024_03_26_235813_change_model_size.php => 2024_03_27_013404_change_embeddings_on_document_chunks.php} (100%) create mode 100644 tests/Feature/LlmDriverClientTest.php diff --git a/app/Jobs/SummarizeDataJob.php b/app/Jobs/SummarizeDataJob.php index caa55a31..91757bf2 100644 --- a/app/Jobs/SummarizeDataJob.php +++ b/app/Jobs/SummarizeDataJob.php @@ -49,7 +49,9 @@ public function handle(): void EOD; /** @var CompletionResponse $results */ - $results = LlmDriverFacade::completion($prompt); + $results = LlmDriverFacade::driver( + $this->documentChunk->getDriver() + )->completion($prompt); $this->documentChunk->update([ 'summary' => $results->content, diff --git a/app/Jobs/SummarizeDocumentJob.php b/app/Jobs/SummarizeDocumentJob.php index 00746700..095f6c5c 100644 --- a/app/Jobs/SummarizeDocumentJob.php +++ b/app/Jobs/SummarizeDocumentJob.php @@ -52,7 +52,9 @@ public function handle(): void EOD; /** @var CompletionResponse $results */ - $results = LlmDriverFacade::completion($prompt); + $results = LlmDriverFacade::driver( + $this->document->getDriver() + )->completion($prompt); $this->document->update([ 'summary' => $results->content, diff --git a/app/Jobs/VectorlizeDataJob.php b/app/Jobs/VectorlizeDataJob.php index 1399866a..868762a5 100644 --- a/app/Jobs/VectorlizeDataJob.php +++ b/app/Jobs/VectorlizeDataJob.php @@ -31,7 +31,7 @@ public function __construct(public DocumentChunk $documentChunk) public function handle(): void { - if ($this->batch()->cancelled()) { + if (optional($this->batch())->cancelled()) { // Determine if the batch has been cancelled... $this->documentChunk->update([ 'status_embeddings' => StatusEnum::Cancelled, @@ -43,7 +43,8 @@ public function handle(): void $content = $this->documentChunk->content; /** @var EmbeddingsResponseDto $results */ - $results = LlmDriverFacade::embedData($content); + $results = LlmDriverFacade::driver($this->documentChunk->getDriver()) + ->embedData($content); $this->documentChunk->update([ 'embedding' => $results->embedding, diff --git a/app/LlmDriver/LlmDriverClient.php b/app/LlmDriver/LlmDriverClient.php index d920c612..e651906e 100644 --- a/app/LlmDriver/LlmDriverClient.php +++ b/app/LlmDriver/LlmDriverClient.php @@ -4,29 +4,38 @@ class LlmDriverClient { - protected $config = []; + protected $drivers = []; - public function __construct(array $config) + public function driver($name = null) { - $this->config = $config; - } + $name = $name ?: $this->getDefaultDriver(); - public static function make(): BaseClient - { - $driver = config('llmdriver.driver'); + if (! isset($this->drivers[$name])) { + $this->drivers[$name] = $this->createDriver($name); + } - $config = config("llmdriver.drivers.{$driver}"); + return $this->drivers[$name]; + } - if (! method_exists(static::class, $driver)) { - throw new \Exception("Driver {$driver} not found"); + protected function createDriver($name) + { + switch ($name) { + case 'openai': + return new OpenAiClient(); + case 'mock': + return new MockClient(); + default: + throw new \InvalidArgumentException("Driver [{$name}] is not supported."); } + } - /** @phpstan-ignore-next-line */ - return (new static($config))->$driver(); + public static function getDrivers(): array + { + return array_keys(config('llmdriver.drivers')); } - public function mock(): BaseClient + protected function getDefaultDriver() { - return new MockClient(); + return 'mock'; } } diff --git a/app/LlmDriver/MockClient.php b/app/LlmDriver/MockClient.php index 7de1f127..48b0da32 100644 --- a/app/LlmDriver/MockClient.php +++ b/app/LlmDriver/MockClient.php @@ -2,10 +2,6 @@ namespace App\LlmDriver; -use App\LlmDriver\Responses\CompletionResponse; -use Illuminate\Support\Facades\Log; - class MockClient extends BaseClient { - } diff --git a/app/LlmDriver/OpenAiClient.php b/app/LlmDriver/OpenAiClient.php index f1e7731b..092bc6f7 100644 --- a/app/LlmDriver/OpenAiClient.php +++ b/app/LlmDriver/OpenAiClient.php @@ -1,10 +1,9 @@ -embeddings as $embedding) { - $embedding->object; // 'embedding' $results = $embedding->embedding; // [0.018990106880664825, -0.0073809814639389515, ...] - $embedding->index; // 0 } return new EmbeddingsResponseDto( @@ -35,19 +32,19 @@ public function embedData(string $data): EmbeddingsResponseDto public function completion(string $prompt, int $temperature = 0): CompletionResponse { - $response = OpenAI::completions()->create([ + $response = OpenAI::chat()->create([ 'model' => $this->getConfig('openai')['completion_model'], - 'prompt' => $prompt, - 'temperature' => 0 + 'messages' => [ + ['role' => 'user', 'content' => $prompt], + ], ]); $results = null; foreach ($response->choices as $result) { - $results = $result->text; // '\n\nThis is a test' + $results = $result->message->content; } return new CompletionResponse($results); } - -} \ No newline at end of file +} diff --git a/app/Models/Document.php b/app/Models/Document.php index f9e2ac32..c436533b 100644 --- a/app/Models/Document.php +++ b/app/Models/Document.php @@ -55,4 +55,9 @@ public function mkdirPathToFile(): ?string $this->collection_id ); } + + public function getDriver(): string + { + return $this->collection->driver; + } } diff --git a/app/Models/DocumentChunk.php b/app/Models/DocumentChunk.php index afbfc034..40c536c5 100644 --- a/app/Models/DocumentChunk.php +++ b/app/Models/DocumentChunk.php @@ -7,6 +7,10 @@ use Illuminate\Database\Eloquent\Model; use Pgvector\Laravel\Vector; +/** + * @property Document $document + * @package App\Models + */ class DocumentChunk extends Model { use HasFactory; @@ -37,4 +41,9 @@ protected static function booted() $document_chunk->saveQuietly(); }); } + + public function getDriver(): string + { + return $this->document->collection->driver; + } } diff --git a/app/Providers/AppServiceProvider.php b/app/Providers/AppServiceProvider.php index 7c097d9a..e6b8a28b 100644 --- a/app/Providers/AppServiceProvider.php +++ b/app/Providers/AppServiceProvider.php @@ -27,7 +27,7 @@ public function boot(): void }); $this->app->bind('llm_driver', function () { - return LlmDriverClient::make(); + return new LlmDriverClient(); }); } diff --git a/database/migrations/2024_03_27_005749_add_driver_to_collections.php b/database/migrations/2024_03_27_005749_add_driver_to_collections.php new file mode 100644 index 00000000..90b0b7df --- /dev/null +++ b/database/migrations/2024_03_27_005749_add_driver_to_collections.php @@ -0,0 +1,28 @@ +string('driver')->default('mock'); + }); + } + + /** + * Reverse the migrations. + */ + public function down(): void + { + Schema::table('collections', function (Blueprint $table) { + $table->dropColumn('driver'); + }); + } +}; diff --git a/database/migrations/2024_03_26_235813_change_model_size.php b/database/migrations/2024_03_27_013404_change_embeddings_on_document_chunks.php similarity index 100% rename from database/migrations/2024_03_26_235813_change_model_size.php rename to database/migrations/2024_03_27_013404_change_embeddings_on_document_chunks.php diff --git a/tests/Feature/Jobs/VectorlizeDataJobTest.php b/tests/Feature/Jobs/VectorlizeDataJobTest.php index b1472c71..336cc950 100644 --- a/tests/Feature/Jobs/VectorlizeDataJobTest.php +++ b/tests/Feature/Jobs/VectorlizeDataJobTest.php @@ -21,7 +21,7 @@ public function test_gets_data(): void 1000 ); - LlmDriverFacade::shouldReceive('embedData') + LlmDriverFacade::shouldReceive('driver->embedData') ->once() ->andReturn($dto); diff --git a/tests/Feature/LlmDriverClientTest.php b/tests/Feature/LlmDriverClientTest.php new file mode 100644 index 00000000..4a8b7970 --- /dev/null +++ b/tests/Feature/LlmDriverClientTest.php @@ -0,0 +1,28 @@ +assertInstanceOf(MockClient::class, $results); + } + + public function test_driver_openai(): void + { + $results = LlmDriverFacade::driver('openai'); + + $this->assertInstanceOf(OpenAiClient::class, $results); + } +} diff --git a/tests/Feature/LlmDriverFacadeTest.php b/tests/Feature/LlmDriverFacadeTest.php index 37ea07be..c69df827 100644 --- a/tests/Feature/LlmDriverFacadeTest.php +++ b/tests/Feature/LlmDriverFacadeTest.php @@ -13,7 +13,7 @@ class LlmDriverFacadeTest extends TestCase */ public function test_facade(): void { - $results = LlmDriverFacade::embedData('test'); + $results = LlmDriverFacade::driver('mock')->embedData('test'); $this->assertInstanceOf( EmbeddingsResponseDto::class, diff --git a/tests/Feature/OpenAiClientTest.php b/tests/Feature/OpenAiClientTest.php index 626cff19..63e04d38 100644 --- a/tests/Feature/OpenAiClientTest.php +++ b/tests/Feature/OpenAiClientTest.php @@ -2,16 +2,13 @@ namespace Tests\Feature; -use Illuminate\Foundation\Testing\RefreshDatabase; -use Illuminate\Foundation\Testing\WithFaker; -use Tests\TestCase; - use App\LlmDriver\Responses\CompletionResponse; use App\LlmDriver\Responses\EmbeddingsResponseDto; -use Illuminate\Support\Facades\Log; use OpenAI\Laravel\Facades\OpenAI; +use OpenAI\Responses\Chat\CreateResponse as ChatCreateResponse; use OpenAI\Responses\Completions\CreateResponse as CompletionsCreateResponse; use OpenAI\Responses\Embeddings\CreateResponse; +use Tests\TestCase; class OpenAiClientTest extends TestCase { @@ -38,10 +35,12 @@ public function test_openai_client(): void public function test_completion(): void { OpenAI::fake([ - CompletionsCreateResponse::fake([ + ChatCreateResponse::fake([ 'choices' => [ [ - 'choice' => 'awesome!', + 'message' => [ + 'content' => 'awesome!' + ], ], ], ]), diff --git a/tests/Feature/SummarizeDataJobTest.php b/tests/Feature/SummarizeDataJobTest.php index 69176a1d..5d55391a 100644 --- a/tests/Feature/SummarizeDataJobTest.php +++ b/tests/Feature/SummarizeDataJobTest.php @@ -18,7 +18,7 @@ public function test_gets_data(): void $data = 'Foo bar'; $dto = new \App\LlmDriver\Responses\CompletionResponse($data); - LlmDriverFacade::shouldReceive('completion') + LlmDriverFacade::shouldReceive('driver->completion') ->once() ->andReturn($dto);