Skip to content

Commit

Permalink
Merge pull request #4 from LlmLaraHub/embeddings_different_systems
Browse files Browse the repository at this point in the history
Embeddings different systems
  • Loading branch information
alnutile authored Apr 9, 2024
2 parents 4643963 + 8e18d98 commit 8391507
Show file tree
Hide file tree
Showing 14 changed files with 184 additions and 30 deletions.
4 changes: 3 additions & 1 deletion app/Domains/Messages/SearchOrSummarizeChatRepo.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ public function search(Chat $chat, string $input): string
$chat->chatable->getEmbeddingDriver()
)->embedData($input);

$embeddingSize = get_embedding_size($chat->chatable->getEmbeddingDriver());

$results = DocumentChunk::query()
->join('documents', 'documents.id', '=', 'document_chunks.document_id')
->selectRaw(
'document_chunks.embedding <-> ? as distance, document_chunks.content, document_chunks.embedding as embedding, document_chunks.id as id',
"document_chunks.{$embeddingSize} <-> ? as distance, document_chunks.content, document_chunks.{$embeddingSize} as embedding, document_chunks.id as id",
[$embedding->embedding]
)
->where('documents.collection_id', $chat->chatable->id)
Expand Down
4 changes: 3 additions & 1 deletion app/Jobs/VectorlizeDataJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ public function handle(): void
$results = LlmDriverFacade::driver($this->documentChunk->getEmbeddingDriver())
->embedData($content);

$embedding_column = $this->documentChunk->getEmbeddingColumn();

$this->documentChunk->update([
'embedding' => $results->embedding,
$embedding_column => $results->embedding,
'status_embeddings' => StatusEnum::Complete,
]);
}
Expand Down
8 changes: 4 additions & 4 deletions app/LlmDriver/OpenAiClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public function chat(array $messages): CompletionResponse
{

$response = OpenAI::chat()->create([
'model' => $this->getConfig('openai')['chat_model'],
'model' => $this->getConfig('openai')['models']['chat_model'],
'messages' => collect($messages)->map(function ($message) {
return $message->toArray();
})->toArray(),
Expand All @@ -38,7 +38,7 @@ public function embedData(string $data): EmbeddingsResponseDto
{

$response = OpenAI::embeddings()->create([
'model' => $this->getConfig('openai')['embedding_model'],
'model' => $this->getConfig('openai')['models']['embedding_model'],
'input' => $data,
]);

Expand All @@ -57,7 +57,7 @@ public function embedData(string $data): EmbeddingsResponseDto
public function completion(string $prompt, int $temperature = 0): CompletionResponse
{
$response = OpenAI::chat()->create([
'model' => $this->getConfig('openai')['completion_model'],
'model' => $this->getConfig('openai')['models']['completion_model'],
'messages' => [
['role' => 'user', 'content' => $prompt],
],
Expand Down Expand Up @@ -87,7 +87,7 @@ public function functionPromptChat(array $messages, array $only = []): array
$functions = $this->getFunctions();

$response = OpenAI::chat()->create([
'model' => $this->getConfig('openai')['chat_model'],
'model' => $this->getConfig('openai')['models']['chat_model'],
'messages' => collect($messages)->map(function ($message) {
return $message->toArray();
})->toArray(),
Expand Down
2 changes: 0 additions & 2 deletions app/LlmDriver/Orchestrate.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ public function handle(array $messagesArray, Chat $chat): ?string
$functions = LlmDriverFacade::driver($chat->chatable->getDriver())
->functionPromptChat($messagesArray);

put_fixture('orchestrate_functions_ollama.json', $functions);

if ($this->hasFunctions($functions)) {
/**
* @TODO
Expand Down
12 changes: 11 additions & 1 deletion app/Models/DocumentChunk.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ class DocumentChunk extends Model implements HasDrivers
use HasFactory;

protected $casts = [
'embedding' => Vector::class,
'embedding_3072' => Vector::class,
'embedding_1536' => Vector::class,
'embedding_2048' => Vector::class,
'embedding_4096' => Vector::class,
'status_embeddings' => StatusEnum::class,
'status_tagging' => StatusEnum::class,
'status_summary' => StatusEnum::class,
Expand Down Expand Up @@ -51,4 +54,11 @@ public function getDriver(): string
{
return $this->document->collection->driver->value;
}

public function getEmbeddingColumn(): string
{

return get_embedding_size($this->getEmbeddingDriver());

}
}
25 changes: 25 additions & 0 deletions app/helpers.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use App\LlmDriver\Helpers\TrimText;
use Illuminate\Support\Facades\File;
use Illuminate\Support\Facades\Log;

if (! function_exists('put_fixture')) {
function put_fixture($file_name, $content = [], $json = true)
Expand Down Expand Up @@ -51,3 +52,27 @@ function reduce_text_size(string $text): string
return (new TrimText())->handle($text);
}
}

if (! function_exists('driverHelper')) {
function driverHelper(string $driver, string $key): string
{
Log::info("driverHelper: {$driver} {$key}");

return config("llmdriver.drivers.{$driver}.{$key}");
}
}

if (! function_exists('get_embedding_size')) {
function get_embedding_size(string $ebmedding_driver): string
{
$embeddingModel = driverHelper($ebmedding_driver, 'models.embedding_model');

$size = config('llmdriver.embedding_sizes.'.$embeddingModel);

if ($size) {
return 'embedding_'.$size;
}

return 'embeding_3072';
}
}
26 changes: 19 additions & 7 deletions config/llmdriver.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,31 @@
return [
'driver' => env('LLM_DRIVER', 'mock'),

'embedding_sizes' => [
'mock' => 4096,
'text-embedding-3-large' => 3072,
'text-embedding-3-medium' => 768,
'text-embedding-3-small' => 384,
'ollama' => 4096,
'llama2' => 4096,
'mistral' => 4096,
],

'drivers' => [
'mock' => [

'models' => [
'completion_model' => 'mock',
'embedding_model' => 'mock',
],
],
'openai' => [
'api_key' => env('OPENAI_API_KEY'),
'api_url' => env('OPENAI_API_URL', 'https://api.openai.com/v1'),
'embedding_model' => env('OPENAI_EMBEDDING_MODEL', 'text-embedding-3-large'),
'completion_model' => env('OPENAI_COMPLETION_MODEL', 'gpt-4-turbo-preview'),
'chat_model' => env('OPENAICHAT_MODEL', 'gpt-4-turbo-preview'),
],
'mock' => [

'models' => [
'embedding_model' => env('OPENAI_EMBEDDING_MODEL', 'text-embedding-3-large'),
'completion_model' => env('OPENAI_COMPLETION_MODEL', 'gpt-4-turbo-preview'),
'chat_model' => env('OPENAICHAT_MODEL', 'gpt-4-turbo-preview'),
],
],
'claude' => [
'api_key' => env('CLAUDE_API_KEY'),
Expand Down
25 changes: 24 additions & 1 deletion database/factories/DocumentChunkFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
namespace Database\Factories;

use App\Domains\Documents\StatusEnum;
use App\LlmDriver\DriversEnum;
use App\Models\Collection;
use App\Models\Document;
use Illuminate\Database\Eloquent\Factories\Factory;

Expand All @@ -29,7 +31,28 @@ public function definition(): array
'original_content' => fake()->sentence(10),
'summary' => fake()->sentence(5),
'document_id' => Document::factory(),
'embedding' => data_get($embeddings, 'data.0.embedding'),
'embedding_3072' => data_get($embeddings, 'data.0.embedding'),
'embedding_1536' => null,
'embedding_2048' => null,
'embedding_4096' => null,
];
}

public function openAi(): Factory
{

return $this->state(function (array $attributes) {
$collection = Collection::factory()->create([
'driver' => DriversEnum::OpenAi,
'embedding_driver' => DriversEnum::OpenAi,
]);
$document = Document::factory()->create([
'collection_id' => $collection->id,
]);

return [
'document_id' => $document->id,
];
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
<?php

use Illuminate\Database\Migrations\Migration;
use Illuminate\Database\Schema\Blueprint;
use Illuminate\Support\Facades\Schema;

return new class extends Migration
{
/**
* Run the migrations.
*/
public function up(): void
{
Schema::table('document_chunks', function (Blueprint $table) {
$table->vector('embedding_1536', 1536)->nullable();
$table->vector('embedding_2048', 2048)->nullable();
$table->vector('embedding_3072', 3072)->nullable();
$table->vector('embedding_4096', 4096)->nullable();
$table->dropColumn('embedding');
});
}

/**
* Reverse the migrations.
*/
public function down(): void
{
Schema::table('document_chunks', function (Blueprint $table) {
//
});
}
};
2 changes: 1 addition & 1 deletion resources/js/Pages/Collection/Show.vue
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ onMounted(() => {
});
const reset = () => {
router.reload();
//router.reload();
}
</script>
Expand Down
38 changes: 38 additions & 0 deletions tests/Feature/HelpersTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
<?php

namespace Tests\Feature;

use App\Models\DocumentChunk;
use Tests\TestCase;

class HelpersTest extends TestCase
{
/**
* A basic feature test example.
*/
public function test_config_helper(): void
{

$this->assertEquals('mock', driverHelper('mock', 'models.embedding_model'));

}

public function test_get_embedding_size(): void
{

$model = DocumentChunk::factory()->create();

$embedding_column = get_embedding_size($model->getEmbeddingDriver());

$this->assertEquals('embedding_4096', $embedding_column);

$model = DocumentChunk::factory()
->openAi()
->create();

$embedding_column = get_embedding_size($model->getEmbeddingDriver());

$this->assertEquals('embedding_3072', $embedding_column);

}
}
8 changes: 4 additions & 4 deletions tests/Feature/Jobs/VectorlizeDataJobTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ public function test_gets_data(): void
->once()
->andReturn($dto);

$documentChunk = DocumentChunk::factory()->create([
'embedding' => null,
]);
$documentChunk = DocumentChunk::factory()
->openAi()
->create();

$job = new VectorlizeDataJob($documentChunk);
$job->handle();

$this->assertNotEmpty($documentChunk->embedding);
$this->assertNotEmpty($documentChunk->embedding_3072);
}
}
19 changes: 19 additions & 0 deletions tests/Feature/Models/DocumentChunkTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,23 @@ public function test_dc_rel()
$this->assertCount(1, $model->document->document_chunks);
$this->assertNotNull($model->document->document_chunks()->first()->id);
}

public function test_embedding_dynamic()
{

$model = DocumentChunk::factory()->create();

$embedding_column = $model->getEmbeddingColumn();

$this->assertEquals('embedding_4096', $embedding_column);

$model = DocumentChunk::factory()
->openAi()
->create();

$embedding_column = $model->getEmbeddingColumn();

$this->assertEquals('embedding_3072', $embedding_column);

}
}
9 changes: 1 addition & 8 deletions tests/fixtures/orchestrate_functions_ollama.json
Original file line number Diff line number Diff line change
@@ -1,8 +1 @@
[
{
"name": "summarize_collection",
"arguments": [
"TLDR it for me"
]
}
]
[]

0 comments on commit 8391507

Please sign in to comment.