Skip to content

Commit

Permalink
Merge pull request #13 from LlmLaraHub/chunk_the_chunks
Browse files Browse the repository at this point in the history
Chunk the chunks
  • Loading branch information
alnutile authored May 11, 2024
2 parents 0197e6a + 17b0466 commit c3eb4c8
Show file tree
Hide file tree
Showing 82 changed files with 2,278 additions and 549 deletions.
2 changes: 1 addition & 1 deletion Modules/LlmDriver/app/BaseClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ function ($item) {
)->implode('\n');

$systemPrompt = <<<EOD
You are a helpful assistant in a RAG system with tools and functions to help perform tasks.
You are a helpful assistant in a Retrieval augmented generation system (RAG - an architectural approach that can improve the efficacy of large language model (LLM) applications by leveraging custom data) system with tools and functions to help perform tasks.
When you find the right function make sure to return just the JSON that represents the requirements of that function.
If no function is found just return {} empty json
Expand Down
81 changes: 46 additions & 35 deletions Modules/LlmDriver/app/DistanceQuery.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@ class DistanceQuery
protected int $distanceThreshold = 0;

/**
* @NOTES
* Some of the reasoning:
* Cosine Similarity: Cosine similarity is often considered one of the most effective metrics for measuring similarity between documents, especially when dealing with high-dimensional data like text documents. It's robust to differences in document length and is effective at capturing semantic similarity.
* Inner Product: Inner product similarity is another metric that can be effective, particularly for certain types of data. It measures the alignment between vectors, which can be useful in contexts where the direction of the vectors is important.
* L2 (Euclidean) Distance: L2 distance is a straightforward metric that measures the straight-line distance between vectors. While it's commonly used and easy to understand, it may not always be the most effective for capturing complex relationships between documents, especially in high-dimensional spaces.
*
* @TODO
 * Track the document page for reference
 * Currently only the distance is saved — should the cosine and inner_product scores be saved as well?
*
* @see https://github.com/orgs/LlmLaraHub/projects/1?pane=issue&itemId=60394288
*/
Expand All @@ -24,45 +30,19 @@ public function distance(
int $collectionId,
Vector $embedding
): Collection {

$documentIds = Document::query()
->select('id')
->where('documents.collection_id', $collectionId)
->orderBy('id')
->pluck('id');

$commonQuery = DocumentChunk::query()
->orderBy('sort_order')
->orderBy('section_number')
->whereIn('document_id', $documentIds);

// Find nearest neighbors using L2 distance
$documentChunkResults = $commonQuery
->nearestNeighbors($embeddingSize, $embedding, Distance::L2)
->take(5)
->get();

// Get IDs of the nearest neighbors found 5
$nearestNeighborIds = $documentChunkResults->pluck('id')->toArray();
Log::info('[LaraChain] Nearest Neighbor IDs', [
'count' => count($nearestNeighborIds),
'ids' => $nearestNeighborIds,
]);
// Find nearest neighbors using InnerProduct distance
$neighborsInnerProduct = $commonQuery
->whereNotIn('document_chunks.id', $nearestNeighborIds)
->nearestNeighbors($embeddingSize, $embedding, Distance::InnerProduct)
->get();

// Find nearest neighbors using Cosine distance found 0
$neighborsInnerProductIds = $neighborsInnerProduct->pluck('id')->toArray();

Log::info('[LaraChain] Nearest Neighbor Inner Product IDs', [
'count' => count($neighborsInnerProductIds),
'ids' => $neighborsInnerProductIds,
]);

$neighborsCosine = $commonQuery
->whereNotIn('id', $nearestNeighborIds)
->when(! empty($neighborsInnerProductIds), function ($query) use ($neighborsInnerProductIds) {
return $query->whereNotIn('id', $neighborsInnerProductIds);
})
->nearestNeighbors($embeddingSize, $embedding, Distance::Cosine)
->get();

Expand All @@ -72,11 +52,42 @@ public function distance(
]);

$results = collect($neighborsCosine)
->merge($neighborsInnerProduct)
->merge($documentChunkResults)
->unique('id')
->take(10);
->take(8);

$siblingsIncluded = collect();

foreach ($results as $result) {
if ($result->section_number === 0) {
$siblingsIncluded->push($result);
} else {
if ($sibling = $this->getSiblingOrNot($result, $result->section_number - 1)) {
$siblingsIncluded->push($sibling);
}

$siblingsIncluded->push($result);
}

if ($sibling = $this->getSiblingOrNot($result, $result->section_number + 1)) {
$siblingsIncluded->push($sibling);
}
}

return $siblingsIncluded;
}

protected function getSiblingOrNot(DocumentChunk $result, int $sectionNumber): false|DocumentChunk
{
$sibling = DocumentChunk::query()
->where('document_id', $result->document_id)
->where('sort_order', $result->sort_order)
->where('section_number', $sectionNumber)
->first();

if ($sibling?->id) {
return $sibling;
}

return $results;
return false;
}
}
71 changes: 42 additions & 29 deletions Modules/LlmDriver/app/Functions/SearchAndSummarize.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
namespace LlmLaraHub\LlmDriver\Functions;

use App\Domains\Agents\VerifyPromptInputDto;
use App\Domains\Agents\VerifyPromptOutputDto;
use App\Domains\Messages\RoleEnum;
use App\Domains\Prompts\SummarizePrompt;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use Illuminate\Support\Facades\Log;
Expand All @@ -24,14 +26,16 @@ class SearchAndSummarize extends FunctionContract

protected string $description = 'Used to embed users prompt, search database and return summarized results.';

protected string $response = '';

/**
* @param MessageInDto[] $messageArray
*/
public function handle(
array $messageArray,
HasDrivers $model,
FunctionCallDto $functionCallDto): FunctionResponse
{
FunctionCallDto $functionCallDto
): FunctionResponse {
Log::info('[LaraChain] Using Function: SearchAndSummarize');

/**
Expand Down Expand Up @@ -80,20 +84,10 @@ public function handle(

$context = implode(' ', $content);

$contentFlattened = <<<PROMPT
You are a helpful assistant in the RAG system:
This is data from the search results when entering the users prompt which is
### START PROMPT
{$originalPrompt}
### END PROMPT
Please use this with the following context and only this, summarize it for the user and return as markdown so I can render it and strip out and formatting like extra spaces, tabs, periods etc:
### START Context
$context
### END Context
PROMPT;
$contentFlattened = SummarizePrompt::prompt(
originalPrompt: $originalPrompt,
context: $context
);

$model->getChat()->addInput(
message: $contentFlattened,
Expand All @@ -102,7 +96,10 @@ public function handle(
show_in_thread: false
);

Log::info('[LaraChain] Getting the Summary from the search results');
Log::info('[LaraChain] Getting the Search and Summary results', [
'input' => $contentFlattened,
'driver' => $model->getChat()->chatable->getDriver(),
]);

$messageArray = MessageInDto::from([
'content' => $contentFlattened,
Expand All @@ -114,8 +111,33 @@ public function handle(
/** @var CompletionResponse $response */
$response = LlmDriverFacade::driver(
$model->getChatable()->getDriver()
)->chat([$messageArray]);
)->completion($contentFlattened);

$this->response = $response->content;

if (Feature::active('verification_prompt')) {
$this->verify($model, $originalPrompt, $context);
}

$message = $model->getChat()->addInput($this->response, RoleEnum::Assistant);

$this->saveDocumentReference($message, $documentChunkResults);

notify_ui($model->getChat(), 'Complete');

return FunctionResponse::from(
[
'content' => $this->response,
'save_to_message' => false,
]
);
}

protected function verify(
HasDrivers $model,
string $originalPrompt,
string $context
): void {
/**
* Lets Verify
*/
Expand All @@ -129,7 +151,7 @@ public function handle(
'chattable' => $model->getChat(),
'originalPrompt' => $originalPrompt,
'context' => $context,
'llmResponse' => $response->content,
'llmResponse' => $this->response,
'verifyPrompt' => $verifyPrompt,
]
);
Expand All @@ -139,16 +161,7 @@ public function handle(
/** @var VerifyPromptOutputDto $response */
$response = VerifyResponseAgent::verify($dto);

$message = $model->getChat()->addInput($response->response, RoleEnum::Assistant);

$this->saveDocumentReference($message, $documentChunkResults);

return FunctionResponse::from(
[
'content' => $response->response,
'save_to_message' => false,
]
);
$this->response = $response->response;
}

/**
Expand Down
62 changes: 41 additions & 21 deletions Modules/LlmDriver/app/Functions/SummarizeCollection.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

use App\Domains\Agents\VerifyPromptInputDto;
use App\Domains\Agents\VerifyPromptOutputDto;
use App\Domains\Chat\UiStatusEnum;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Illuminate\Support\Facades\Log;
use Laravel\Pennant\Feature;
use LlmLaraHub\LlmDriver\HasDrivers;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Prompts\SummarizeCollectionPrompt;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\FunctionResponse;

Expand All @@ -17,6 +20,8 @@ class SummarizeCollection extends FunctionContract

protected string $description = 'NOT FOR SEARCH, This is used when the prompt wants to summarize the entire collection of documents';

protected string $response = '';

public function handle(
array $messageArray,
HasDrivers $model,
Expand All @@ -32,23 +37,15 @@ public function handle(
*/
foreach ($model->chatable->documents as $document) {
foreach ($document->document_chunks as $chunk) {
$summary->add($chunk->summary);
$summary->add($chunk->content);
}
}

notify_ui($model->getChat(), 'Getting Summary');

$summary = $summary->implode('\n');

$prompt = <<<PROMPT
Can you summarize all of this content for me from a collection of documents I uploaded what
follows is the content:
### START ALL SUMMARY DATA
$summary
### END ALL SUMMARY DATA
PROMPT;
$prompt = SummarizeCollectionPrompt::prompt($summary);

$messagesArray = [];

Expand All @@ -59,30 +56,53 @@ public function handle(

$results = LlmDriverFacade::driver($model->getDriver())->chat($messagesArray);

notify_ui($model->getChat(), 'Summary complete going to do one verfication check on the summarhy');
$this->response = $results->content;

notify_ui($model->getChat(), 'Summary complete');

if (Feature::active('verification_prompt')) {
Log::info('[LaraChain] Verifying Summary Collection');
$this->verify($model, 'Can you summarize this collection of data for me.', $summary);
}

notify_ui($model->getChat(), UiStatusEnum::Complete->name);

return FunctionResponse::from([
'content' => $this->response,
'prompt' => $prompt,
'requires_followup' => true,
]);
}

protected function verify(
HasDrivers $model,
string $originalPrompt,
string $context
): void {
/**
* Lets Verify
*/
$verifyPrompt = <<<'PROMPT'
This the content from all the documents in this collection.
Then that was passed into the LLM to summarize the results.
PROMPT;
This the content from all the documents in this collection.
Then that was passed into the LLM to summarize the results.
PROMPT;

$dto = VerifyPromptInputDto::from(
[
'chattable' => $model->getChat(),
'originalPrompt' => 'Can you summarize this collection of data for me.',
'context' => $summary,
'llmResponse' => $results->content,
'originalPrompt' => $originalPrompt,
'context' => $context,
'llmResponse' => $this->response,
'verifyPrompt' => $verifyPrompt,
]
);

notify_ui($model, 'Verifiying Results');

/** @var VerifyPromptOutputDto $response */
$response = VerifyResponseAgent::verify($dto);

return FunctionResponse::from([
'content' => $response->response,
'requires_followup' => true,
]);
$this->response = $response->response;
}

/**
Expand Down
Loading

0 comments on commit c3eb4c8

Please sign in to comment.