Skip to content

Commit

Permalink
driver working
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed Mar 27, 2024
1 parent 6a8b130 commit e53dcaa
Show file tree
Hide file tree
Showing 16 changed files with 119 additions and 43 deletions.
4 changes: 3 additions & 1 deletion app/Jobs/SummarizeDataJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ public function handle(): void
EOD;

/** @var CompletionResponse $results */
$results = LlmDriverFacade::completion($prompt);
$results = LlmDriverFacade::driver(
$this->documentChunk->getDriver()
)->completion($prompt);

$this->documentChunk->update([
'summary' => $results->content,
Expand Down
4 changes: 3 additions & 1 deletion app/Jobs/SummarizeDocumentJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ public function handle(): void
EOD;

/** @var CompletionResponse $results */
$results = LlmDriverFacade::completion($prompt);
$results = LlmDriverFacade::driver(
$this->document->getDriver()
)->completion($prompt);

$this->document->update([
'summary' => $results->content,
Expand Down
5 changes: 3 additions & 2 deletions app/Jobs/VectorlizeDataJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public function __construct(public DocumentChunk $documentChunk)
public function handle(): void
{

if ($this->batch()->cancelled()) {
if (optional($this->batch())->cancelled()) {
// Determine if the batch has been cancelled...
$this->documentChunk->update([
'status_embeddings' => StatusEnum::Cancelled,
Expand All @@ -43,7 +43,8 @@ public function handle(): void
$content = $this->documentChunk->content;

/** @var EmbeddingsResponseDto $results */
$results = LlmDriverFacade::embedData($content);
$results = LlmDriverFacade::driver($this->documentChunk->getDriver())
->embedData($content);

$this->documentChunk->update([
'embedding' => $results->embedding,
Expand Down
37 changes: 23 additions & 14 deletions app/LlmDriver/LlmDriverClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,38 @@

class LlmDriverClient
{
protected $config = [];
protected $drivers = [];

public function __construct(array $config)
public function driver($name = null)
{
$this->config = $config;
}
$name = $name ?: $this->getDefaultDriver();

public static function make(): BaseClient
{
$driver = config('llmdriver.driver');
if (! isset($this->drivers[$name])) {
$this->drivers[$name] = $this->createDriver($name);
}

$config = config("llmdriver.drivers.{$driver}");
return $this->drivers[$name];
}

if (! method_exists(static::class, $driver)) {
throw new \Exception("Driver {$driver} not found");
protected function createDriver($name)
{
switch ($name) {
case 'openai':
return new OpenAiClient();
case 'mock':
return new MockClient();
default:
throw new \InvalidArgumentException("Driver [{$name}] is not supported.");
}
}

/** @phpstan-ignore-next-line */
return (new static($config))->$driver();
public static function getDrivers(): array
{
return array_keys(config('llmdriver.drivers'));
}

public function mock(): BaseClient
protected function getDefaultDriver()
{
return new MockClient();
return 'mock';
}
}
4 changes: 0 additions & 4 deletions app/LlmDriver/MockClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

namespace App\LlmDriver;

use App\LlmDriver\Responses\CompletionResponse;
use Illuminate\Support\Facades\Log;

class MockClient extends BaseClient
{

}
17 changes: 7 additions & 10 deletions app/LlmDriver/OpenAiClient.php
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
<?php
<?php

namespace App\LlmDriver;

use App\LlmDriver\Responses\CompletionResponse;
use App\LlmDriver\Responses\EmbeddingsResponseDto;
use Illuminate\Support\Facades\Log;
use OpenAI\Laravel\Facades\OpenAI;

class OpenAiClient extends BaseClient
Expand All @@ -22,9 +21,7 @@ public function embedData(string $data): EmbeddingsResponseDto
$results = [];

foreach ($response->embeddings as $embedding) {
$embedding->object; // 'embedding'
$results = $embedding->embedding; // [0.018990106880664825, -0.0073809814639389515, ...]
$embedding->index; // 0
}

return new EmbeddingsResponseDto(
Expand All @@ -35,19 +32,19 @@ public function embedData(string $data): EmbeddingsResponseDto

public function completion(string $prompt, int $temperature = 0): CompletionResponse
{
$response = OpenAI::completions()->create([
$response = OpenAI::chat()->create([
'model' => $this->getConfig('openai')['completion_model'],
'prompt' => $prompt,
'temperature' => 0
'messages' => [
['role' => 'user', 'content' => $prompt],
],
]);

$results = null;

foreach ($response->choices as $result) {
$results = $result->text; // '\n\nThis is a test'
$results = $result->message->content;
}

return new CompletionResponse($results);
}

}
}
5 changes: 5 additions & 0 deletions app/Models/Document.php
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,9 @@ public function mkdirPathToFile(): ?string
$this->collection_id
);
}

public function getDriver(): string
{
return $this->collection->driver;
}
}
9 changes: 9 additions & 0 deletions app/Models/DocumentChunk.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
use Illuminate\Database\Eloquent\Model;
use Pgvector\Laravel\Vector;

/**
* @property Document $document
* @package App\Models
*/
class DocumentChunk extends Model
{
use HasFactory;
Expand Down Expand Up @@ -37,4 +41,9 @@ protected static function booted()
$document_chunk->saveQuietly();
});
}

public function getDriver(): string
{
return $this->document->collection->driver;
}
}
2 changes: 1 addition & 1 deletion app/Providers/AppServiceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public function boot(): void
});

$this->app->bind('llm_driver', function () {
return LlmDriverClient::make();
return new LlmDriverClient();
});

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?php

use Illuminate\Database\Migrations\Migration;
use Illuminate\Database\Schema\Blueprint;
use Illuminate\Support\Facades\Schema;

return new class extends Migration
{
/**
* Run the migrations.
*/
public function up(): void
{
Schema::table('collections', function (Blueprint $table) {
$table->string('driver')->default('mock');
});
}

/**
* Reverse the migrations.
*/
public function down(): void
{
Schema::table('collections', function (Blueprint $table) {
$table->dropColumn('driver');
});
}
};
2 changes: 1 addition & 1 deletion tests/Feature/Jobs/VectorlizeDataJobTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public function test_gets_data(): void
1000
);

LlmDriverFacade::shouldReceive('embedData')
LlmDriverFacade::shouldReceive('driver->embedData')
->once()
->andReturn($dto);

Expand Down
28 changes: 28 additions & 0 deletions tests/Feature/LlmDriverClientTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?php

namespace Tests\Feature;

use App\LlmDriver\LlmDriverFacade;
use App\LlmDriver\MockClient;
use App\LlmDriver\OpenAiClient;
use Tests\TestCase;

class LlmDriverClientTest extends TestCase
{
/**
* A basic feature test example.
*/
public function test_driver(): void
{
$results = LlmDriverFacade::driver('mock');

$this->assertInstanceOf(MockClient::class, $results);
}

public function test_driver_openai(): void
{
$results = LlmDriverFacade::driver('openai');

$this->assertInstanceOf(OpenAiClient::class, $results);
}
}
2 changes: 1 addition & 1 deletion tests/Feature/LlmDriverFacadeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class LlmDriverFacadeTest extends TestCase
*/
public function test_facade(): void
{
$results = LlmDriverFacade::embedData('test');
$results = LlmDriverFacade::driver('mock')->embedData('test');

$this->assertInstanceOf(
EmbeddingsResponseDto::class,
Expand Down
13 changes: 6 additions & 7 deletions tests/Feature/OpenAiClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@

namespace Tests\Feature;

use Illuminate\Foundation\Testing\RefreshDatabase;
use Illuminate\Foundation\Testing\WithFaker;
use Tests\TestCase;

use App\LlmDriver\Responses\CompletionResponse;
use App\LlmDriver\Responses\EmbeddingsResponseDto;
use Illuminate\Support\Facades\Log;
use OpenAI\Laravel\Facades\OpenAI;
use OpenAI\Responses\Chat\CreateResponse as ChatCreateResponse;
use OpenAI\Responses\Completions\CreateResponse as CompletionsCreateResponse;
use OpenAI\Responses\Embeddings\CreateResponse;
use Tests\TestCase;

class OpenAiClientTest extends TestCase
{
Expand All @@ -38,10 +35,12 @@ public function test_openai_client(): void
public function test_completion(): void
{
OpenAI::fake([
CompletionsCreateResponse::fake([
ChatCreateResponse::fake([
'choices' => [
[
'choice' => 'awesome!',
'message' => [
'content' => 'awesome!'
],
],
],
]),
Expand Down
2 changes: 1 addition & 1 deletion tests/Feature/SummarizeDataJobTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public function test_gets_data(): void
$data = 'Foo bar';
$dto = new \App\LlmDriver\Responses\CompletionResponse($data);

LlmDriverFacade::shouldReceive('completion')
LlmDriverFacade::shouldReceive('driver->completion')
->once()
->andReturn($dto);

Expand Down

0 comments on commit e53dcaa

Please sign in to comment.