Skip to content

Commit

Permalink
Merge pull request #12 from LlmLaraHub/distance
Browse files Browse the repository at this point in the history
Distance Query Updates
  • Loading branch information
alnutile authored May 7, 2024
2 parents edf9c5d + fa6a8af commit ed206d8
Show file tree
Hide file tree
Showing 27 changed files with 1,547 additions and 52 deletions.
78 changes: 78 additions & 0 deletions Modules/LlmDriver/app/DistanceQuery.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
<?php

namespace LlmLaraHub\LlmDriver;

use App\Models\Document;
use App\Models\DocumentChunk;
use Illuminate\Support\Collection;
use Illuminate\Support\Facades\Log;
use Pgvector\Laravel\Distance;
use Pgvector\Laravel\Vector;

class DistanceQuery
{
    /**
     * Maximum number of chunks fetched by each search and returned to the caller.
     */
    protected const RESULT_LIMIT = 5;

    protected int $distanceThreshold = 0;

    /**
     * Find the document chunks of a collection that are closest to an embedding.
     *
     * Three nearest-neighbour searches are run against the same base filter
     * (chunks whose document belongs to $collectionId): L2, inner product and
     * cosine. Each later search excludes the IDs already found, and every
     * search logs the IDs it produced. The returned collection is the L2
     * results merged with the cosine results, de-duplicated by id and capped
     * at RESULT_LIMIT. The inner-product results are only logged and used as
     * an exclusion list.
     *
     * @TODO
     * Track the document page for reference
     *
     * @see https://github.com/orgs/LlmLaraHub/projects/1?pane=issue&itemId=60394288
     *
     * @param  string  $embeddingSize  Name of the embedding column to search (e.g. "embedding_1024").
     * @param  int  $collectionId  Collection whose documents' chunks are searched.
     * @param  Vector  $embedding  Query vector to measure distance against.
     * @return Collection Up to RESULT_LIMIT document chunks.
     */
    public function distance(
        string $embeddingSize,
        int $collectionId,
        Vector $embedding
    ): Collection {
        $documentIds = Document::query()
            ->select('id')
            ->where('documents.collection_id', $collectionId)
            ->pluck('id');

        $commonQuery = DocumentChunk::query()
            ->whereIn('document_id', $documentIds);

        // Query builders are mutable: chaining directly on $commonQuery would
        // carry the L2 ordering, the limit and every whereNotIn() into the
        // following searches. Each search therefore works on its own clone.

        // Find nearest neighbors using L2 distance
        $documentChunkResults = (clone $commonQuery)
            ->nearestNeighbors($embeddingSize, $embedding, Distance::L2)
            ->take(self::RESULT_LIMIT)
            ->get();

        $nearestNeighborIds = $documentChunkResults->pluck('id')->toArray();
        Log::info('[LaraChain] Nearest Neighbor IDs', [
            'count' => count($nearestNeighborIds),
            'ids' => $nearestNeighborIds,
        ]);

        // Find nearest neighbors using InnerProduct distance, skipping the L2 hits.
        // The explicit limit keeps the bound the shared builder used to impose.
        $neighborsInnerProduct = (clone $commonQuery)
            ->whereNotIn('document_chunks.id', $nearestNeighborIds)
            ->nearestNeighbors($embeddingSize, $embedding, Distance::InnerProduct)
            ->take(self::RESULT_LIMIT)
            ->get();

        $neighborsInnerProductIds = $neighborsInnerProduct->pluck('id')->toArray();

        Log::info('[LaraChain] Nearest Neighbor Inner Product IDs', [
            'count' => count($neighborsInnerProductIds),
            'ids' => $neighborsInnerProductIds,
        ]);

        // Find nearest neighbors using Cosine distance, skipping both earlier result sets.
        $neighborsCosine = (clone $commonQuery)
            ->whereNotIn('id', $nearestNeighborIds)
            ->when(! empty($neighborsInnerProductIds), function ($query) use ($neighborsInnerProductIds) {
                return $query->whereNotIn('id', $neighborsInnerProductIds);
            })
            ->nearestNeighbors($embeddingSize, $embedding, Distance::Cosine)
            ->take(self::RESULT_LIMIT)
            ->get();

        Log::info('[LaraChain] Nearest Neighbor Cosine IDs', [
            'count' => $neighborsCosine->count(),
            'ids' => $neighborsCosine->pluck('id')->toArray(),
        ]);

        // NOTE(review): inner-product matches are not part of the returned set,
        // and the L2 matches come first, so cosine matches only appear when the
        // L2 search returned fewer than RESULT_LIMIT rows - confirm this is intended.
        return collect($documentChunkResults)
            ->merge($neighborsCosine)
            ->unique('id')
            ->take(self::RESULT_LIMIT);
    }
}
6 changes: 3 additions & 3 deletions Modules/LlmDriver/app/Functions/SearchAndSummarize.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
use App\Domains\Agents\VerifyPromptInputDto;
use App\Domains\Messages\RoleEnum;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use Illuminate\Support\Facades\Log;
use Laravel\Pennant\Feature;
use LlmLaraHub\LlmDriver\HasDrivers;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;
use LlmLaraHub\LlmDriver\Helpers\DistanceQueryTrait;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
Expand All @@ -18,7 +18,7 @@

class SearchAndSummarize extends FunctionContract
{
use CreateReferencesTrait, DistanceQueryTrait;
use CreateReferencesTrait;

protected string $name = 'search_and_summarize';

Expand Down Expand Up @@ -57,7 +57,7 @@ public function handle(

notify_ui($model, 'Searching documents');

$documentChunkResults = $this->distance(
$documentChunkResults = DistanceQuery::distance(
$embeddingSize,
$model->getChatable()->id,
$embedding->embedding
Expand Down
5 changes: 3 additions & 2 deletions Modules/LlmDriver/app/Helpers/CreateReferencesTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
namespace LlmLaraHub\LlmDriver\Helpers;

use App\Models\Message;
use Illuminate\Database\Eloquent\Collection;
use Illuminate\Support\Collection;

trait CreateReferencesTrait
{
/**
 * Store one message_document_references row on the message for every chunk given.
 *
 * @param  Message  $model  Message the references are attached to.
 * @param  Collection  $documentChunks  Chunks to reference; each is read for its id and distance value.
 */
protected function saveDocumentReference(
Message $model,
Collection $documentChunks
): void {
//put_fixture("document_chunks.json", $documentChunks->toArray());
//add each one to a batch job or do the work here.
foreach ($documentChunks as $documentChunk) {
// NOTE(review): 'distance' is listed twice below (this looks like the old and
// the new line of the change shown together). In a PHP array literal the later
// duplicate key wins, so the value stored is $documentChunk->neighbor_distance.
$model->message_document_references()->create([
'document_chunk_id' => $documentChunk->id,
'distance' => $documentChunk->distance,
'distance' => $documentChunk->neighbor_distance,
]);
}
}
Expand Down
37 changes: 0 additions & 37 deletions Modules/LlmDriver/app/Helpers/DistanceQueryTrait.php

This file was deleted.

7 changes: 5 additions & 2 deletions Modules/LlmDriver/app/Orchestrate.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ public function handle(array $messagesArray, Chat $chat): ?string
$functions = LlmDriverFacade::driver($chat->chatable->getDriver())
->functionPromptChat($messagesArray);

Log::info("['LaraChain'] Functions Found", $functions);
Log::info("['LaraChain'] Functions Found?", [
'count' => count($functions),
'functions' => $functions,
]);

if ($this->hasFunctions($functions)) {
Log::info('[LaraChain] Orchestration Has Functions', $functions);
Expand Down Expand Up @@ -115,7 +118,7 @@ public function handle(array $messagesArray, Chat $chat): ?string

return $this->response;
} else {
Log::info('[LaraChain] Orchestration No Functions Default SearchAnd Summarize');
Log::info('[LaraChain] Orchestration No Functions Default Search And Summarize');
/**
* @NOTE
* this assumes way too much
Expand Down
38 changes: 38 additions & 0 deletions Modules/LlmDriver/tests/Feature/DistanceQueryTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
<?php

namespace Tests\Feature;

use App\Models\Document;
use App\Models\DocumentChunk;
use Illuminate\Support\Facades\File;
use LlmLaraHub\LlmDriver\DistanceQuery;
use Pgvector\Laravel\Vector;
use Tests\TestCase;

class DistanceQueryTest extends TestCase
{
    /**
     * Seeds a document plus the chunk fixtures found on disk, then checks that a
     * distance search with the fixture question embedding yields exactly one chunk.
     */
    public function test_results()
    {
        $chunkFixtures = File::files(base_path('tests/fixtures/document_chunks'));

        // The id is pinned to 31; presumably the chunk fixtures reference that document id.
        $seededDocument = Document::factory()->create(['id' => 31]);

        collect($chunkFixtures)->each(function ($fixture): void {
            $attributes = json_decode(File::get($fixture), true);

            DocumentChunk::factory()->create($attributes);
        });

        $questionVector = new Vector(get_fixture('embedding_question_distance.json'));

        $matches = (new DistanceQuery())->distance(
            'embedding_1024',
            $seededDocument->collection_id,
            $questionVector
        );

        $this->assertCount(1, $matches);
    }
}
1 change: 1 addition & 0 deletions app/Domains/Agents/VerifyResponseAgent.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public function verify(VerifyPromptInputDto $input): VerifyPromptOutputDto
Your Response will be that just cleaned up for chat.
DO NOT include text like "Here is the cleaned-up response" the user should not even know your step happened :)
Your repsonse will NOT be a list like below but just follow the formatting of the "LLM RESPONSE".
DO NOT get an information outside of this context.
### Included are the following sections
- ORIGINAL PROMPT: The question from the user
Expand Down
12 changes: 8 additions & 4 deletions app/Domains/Messages/SearchAndSummarizeChatRepo.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,38 @@
use App\Models\Chat;
use App\Models\DocumentChunk;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use Illuminate\Support\Facades\Log;
use Laravel\Pennant\Feature;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;
use LlmLaraHub\LlmDriver\Helpers\DistanceQueryTrait;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;

class SearchAndSummarizeChatRepo
{
use CreateReferencesTrait, DistanceQueryTrait;
use CreateReferencesTrait;

public function search(Chat $chat, string $input): string
{
Log::info('[LaraChain] Embedding and Searching');
Log::info('[LaraChain] Search and Summarize Default Function');

$originalPrompt = $input;

notify_ui($chat, 'Searching documents');

Log::info('[LaraChain] Embedding the Data', [
'question' => $input,
]);

/** @var EmbeddingsResponseDto $embedding */
$embedding = LlmDriverFacade::driver(
$chat->chatable->getEmbeddingDriver()
)->embedData($input);

$embeddingSize = get_embedding_size($chat->chatable->getEmbeddingDriver());

$documentChunkResults = $this->distance(
$documentChunkResults = DistanceQuery::distance(
$embeddingSize,
/** @phpstan-ignore-next-line */
$chat->getChatable()->id,
Expand Down
13 changes: 9 additions & 4 deletions app/Http/Resources/MessageDocumentReferenceResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@ class MessageDocumentReferenceResource extends JsonResource
*/
public function toArray(Request $request): array
{
// Nullsafe read: $tags is null when the reference has no document_chunk.
$tags = $this->document_chunk?->tags;
// Only wrap the tags in a resource collection when there is something to wrap;
// otherwise the falsy value is passed through unchanged as 'taggings'.
if ($tags) {
$tags = TagResource::collection($tags);
}

// NOTE(review): 'document_name', 'page', 'summary' and 'taggings' each appear
// twice below (this looks like the old and the new lines of the change shown
// together). In a PHP array literal the later duplicate key wins, so the
// nullsafe (?->) variants and $tags are the values that end up in the array.
return [
'id' => $this->id,
'document_name' => $this->document_chunk->document->file_path,
'page' => $this->document_chunk->sort_order,
'document_name' => $this->document_chunk?->document->file_path,
'page' => $this->document_chunk?->sort_order,
// Distance is rounded to two decimals.
'distance' => round($this->distance, 2),
'summary' => str($this->document_chunk->summary)->markdown(),
'taggings' => TagResource::collection($this->document_chunk->tags),
'summary' => str($this->document_chunk?->summary)->markdown(),
'taggings' => $tags,
];
}
}
2 changes: 2 additions & 0 deletions app/Models/DocumentChunk.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use LlmLaraHub\LlmDriver\HasDrivers;
use LlmLaraHub\TagFunction\Contracts\TaggableContract;
use LlmLaraHub\TagFunction\Helpers\Taggable;
use Pgvector\Laravel\HasNeighbors;
use Pgvector\Laravel\Vector;

/**
Expand All @@ -16,6 +17,7 @@
class DocumentChunk extends Model implements HasDrivers, TaggableContract
{
use HasFactory;
use HasNeighbors;
use Taggable;

protected $casts = [
Expand Down
25 changes: 25 additions & 0 deletions app/helpers.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ function put_fixture($file_name, $content = [], $json = true)
}
}

if (! function_exists('calculate_dynamic_threshold')) {
    /**
     * Average of the smallest distances up to and including the given percentile.
     *
     * The distances are sorted ascending, the element sitting at the requested
     * percentile is located, and the mean of every value from the smallest up
     * to that element is returned.
     *
     * @param  array  $distances  Numeric distances, in any order.
     * @param  int  $percentile  Percentile (0-100) marking the last value included; values
     *                           outside that range are clamped to the first/last element.
     * @return float The mean described above, or 0.0 when $distances is empty.
     */
    function calculate_dynamic_threshold(array $distances, int $percentile = 90): float
    {
        $count = count($distances);

        // Nothing to average: bail out instead of dividing by zero below.
        if ($count === 0) {
            return 0.0;
        }

        // Sort the distances in ascending order
        sort($distances);

        // Zero-based position of the percentile element. ceil() returns a float,
        // so cast it before it is used as a slice length.
        $index = (int) ceil($count * ($percentile / 100)) - 1;

        // Keep the index inside the array: a percentile <= 0 would otherwise give -1
        // (and a zero divisor), a percentile > 100 would run past the end.
        $index = max(0, min($index, $count - 1));

        // Calculate the average of distances up to the percentile index
        return array_sum(array_slice($distances, 0, $index + 1)) / ($index + 1);
    }
}

if (! function_exists('chunk_string')) {
function chunk_string(string $string, int $maxTokenSize): array
{
Expand Down
9 changes: 9 additions & 0 deletions tests/Feature/HelpersTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,13 @@ public function test_get_embedding_size(): void
$this->assertEquals('embedding_3072', $embedding_column);

}

/**
 * With the default percentile all five sample distances fall inside the
 * averaged range, so the expected threshold is their mean.
 */
public function test_calculate_dynamic_threshold()
{
    $this->assertEquals(
        3.25,
        calculate_dynamic_threshold([0.5, 0.75, 2, 3, 10])
    );
}
}
5 changes: 5 additions & 0 deletions tests/Feature/SearchAndSummarizeChatRepoTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use App\Models\Document;
use App\Models\DocumentChunk;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use Tests\TestCase;

Expand Down Expand Up @@ -64,6 +65,10 @@ public function test_can_search(): void
'document_id' => $document->id,
]);

DistanceQuery::shouldReceive('distance')
->once()
->andReturn(DocumentChunk::all());

$results = (new SearchAndSummarizeChatRepo())->search($chat, 'Puppy');

$this->assertNotNull($results);
Expand Down
Loading

0 comments on commit ed206d8

Please sign in to comment.