Skip to content

Commit

Permalink
this is a big change the the core query for the better
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed May 7, 2024
1 parent a7cc85b commit fa6a8af
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 52 deletions.
16 changes: 8 additions & 8 deletions Modules/LlmDriver/app/DistanceQuery.php
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<?php
<?php

namespace LlmLaraHub\LlmDriver;

Expand All @@ -12,6 +12,7 @@
class DistanceQuery
{
protected int $distanceThreshold = 0;

/**
* @TODO
* Track the document page for referehce
Expand All @@ -27,7 +28,7 @@ public function distance(
->select('id')
->where('documents.collection_id', $collectionId)
->pluck('id');

$commonQuery = DocumentChunk::query()
->whereIn('document_id', $documentIds);

Expand All @@ -36,7 +37,7 @@ public function distance(
->nearestNeighbors($embeddingSize, $embedding, Distance::L2)
->take(5)
->get();

// Get IDs of the nearest neighbors found 5
$nearestNeighborIds = $documentChunkResults->pluck('id')->toArray();
Log::info('[LaraChain] Nearest Neighbor IDs', [
Expand All @@ -48,8 +49,8 @@ public function distance(
->whereNotIn('document_chunks.id', $nearestNeighborIds)
->nearestNeighbors($embeddingSize, $embedding, Distance::InnerProduct)
->get();
// Find nearest neighbors using Cosine distance found 0

// Find nearest neighbors using Cosine distance found 0
$neighborsInnerProductIds = $neighborsInnerProduct->pluck('id')->toArray();

Log::info('[LaraChain] Nearest Neighbor Inner Product IDs', [
Expand All @@ -59,7 +60,7 @@ public function distance(

$neighborsCosine = $commonQuery
->whereNotIn('id', $nearestNeighborIds)
->when(!empty($neighborsInnerProductIds), function($query) use ($neighborsInnerProductIds) {
->when(! empty($neighborsInnerProductIds), function ($query) use ($neighborsInnerProductIds) {
return $query->whereNotIn('id', $neighborsInnerProductIds);
})
->nearestNeighbors($embeddingSize, $embedding, Distance::Cosine)
Expand All @@ -74,5 +75,4 @@ public function distance(

return $results;
}

}
}
2 changes: 1 addition & 1 deletion Modules/LlmDriver/app/Functions/SearchAndSummarize.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
use App\Domains\Agents\VerifyPromptInputDto;
use App\Domains\Messages\RoleEnum;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use Illuminate\Support\Facades\Log;
use Laravel\Pennant\Feature;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use LlmLaraHub\LlmDriver\HasDrivers;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
Expand Down
2 changes: 1 addition & 1 deletion Modules/LlmDriver/app/Orchestrate.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public function handle(array $messagesArray, Chat $chat): ?string

Log::info("['LaraChain'] Functions Found?", [
'count' => count($functions),
'functions' => $functions
'functions' => $functions,
]);

if ($this->hasFunctions($functions)) {
Expand Down
22 changes: 7 additions & 15 deletions Modules/LlmDriver/tests/Feature/DistanceQueryTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,34 @@

namespace Tests\Feature;

use App\Models\Collection;
use App\Models\Document;
use App\Models\DocumentChunk;
use Illuminate\Support\Facades\File;
use LlmLaraHub\LlmDriver\DistanceQuery;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;
use OpenAI\Laravel\Facades\OpenAI;
use OpenAI\Responses\Chat\CreateResponse as ChatCreateResponse;
use OpenAI\Responses\Embeddings\CreateResponse;
use Pgvector\Laravel\Vector;
use Tests\TestCase;

class DistanceQueryTest extends TestCase
{


public function test_results() {
public function test_results()
{
$files = File::files(base_path('tests/fixtures/document_chunks'));
$document = Document::factory()->create([
'id' => 31
'id' => 31,
]);

foreach($files as $file) {
foreach ($files as $file) {
$data = json_decode(File::get($file), true);
DocumentChunk::factory()->create($data);
}

$question = get_fixture("embedding_question_distance.json");
$question = get_fixture('embedding_question_distance.json');

$vector = new Vector($question);

$results = (new DistanceQuery())->distance(
'embedding_1024',
$document->collection_id,
'embedding_1024',
$document->collection_id,
$vector);

$this->assertCount(1, $results);
Expand Down
2 changes: 1 addition & 1 deletion app/Domains/Messages/SearchAndSummarizeChatRepo.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
use App\Models\Chat;
use App\Models\DocumentChunk;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use Illuminate\Support\Facades\Log;
use Laravel\Pennant\Feature;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
Expand Down
3 changes: 2 additions & 1 deletion app/Http/Resources/MessageDocumentReferenceResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ class MessageDocumentReferenceResource extends JsonResource
public function toArray(Request $request): array
{
$tags = $this->document_chunk?->tags;
if($tags) {
if ($tags) {
$tags = TagResource::collection($tags);
}

return [
'id' => $this->id,
'document_name' => $this->document_chunk?->document->file_path,
Expand Down
37 changes: 17 additions & 20 deletions app/helpers.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,30 @@ function put_fixture($file_name, $content = [], $json = true)
}
}


if (! function_exists('calculate_dynamic_threshold')) {
function calculate_dynamic_threshold(array $distances, int $percentile = 90): float
{
// Sort the distances in ascending order
sort($distances);

// Get the total number of distances
$count = count($distances);

// Calculate the index for the given percentile
$index = ceil($count * ($percentile / 100)) - 1;

// Ensure the index is within the bounds of the array
if ($index >= $count) {
$index = $count - 1;
}
// Sort the distances in ascending order
sort($distances);

// Calculate the average of distances up to the percentile index
$threshold = array_sum(array_slice($distances, 0, $index + 1)) / ($index + 1);
// Get the total number of distances
$count = count($distances);

return $threshold;

}
}
// Calculate the index for the given percentile
$index = ceil($count * ($percentile / 100)) - 1;

// Ensure the index is within the bounds of the array
if ($index >= $count) {
$index = $count - 1;
}

// Calculate the average of distances up to the percentile index
$threshold = array_sum(array_slice($distances, 0, $index + 1)) / ($index + 1);

return $threshold;

}
}

if (! function_exists('chunk_string')) {
function chunk_string(string $string, int $maxTokenSize): array
Expand Down
5 changes: 3 additions & 2 deletions tests/Feature/HelpersTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ public function test_get_embedding_size(): void

$embedding_column = get_embedding_size($model->getEmbeddingDriver());

$this->assertEquals('embedding_3072', $embedding_column);
$this->assertEquals('embedding_3072', $embedding_column);

}

public function test_calculate_dynamic_threshold() {
public function test_calculate_dynamic_threshold()
{
$distances = [0.5, 0.75, 2, 3, 10];

$threshold = calculate_dynamic_threshold($distances);
Expand Down
3 changes: 0 additions & 3 deletions tests/Feature/SearchAndSummarizeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
use App\Models\Document;
use App\Models\DocumentChunk;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Illuminate\Support\Facades\File;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use LlmLaraHub\LlmDriver\Functions\ParametersDto;
use LlmLaraHub\LlmDriver\Functions\PropertyDto;
Expand All @@ -35,8 +34,6 @@ public function test_can_generate_function_as_array(): void
$this->assertInstanceOf(PropertyDto::class, $parameters->properties[0]);
}



public function test_gets_user_input()
{
$messageArray = [];
Expand Down

0 comments on commit fa6a8af

Please sign in to comment.