Skip to content

Commit

Permalink
Step 1 of RFP reply in place, solutions pulled out of the RFP
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed Jul 12, 2024
1 parent ad932a5 commit 8655ee3
Show file tree
Hide file tree
Showing 124 changed files with 7,740 additions and 151 deletions.
20 changes: 20 additions & 0 deletions Modules/LlmDriver/app/BaseClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ abstract class BaseClient

protected int $poolSize = 3;

protected bool $formatJson = false;

public function setFormatJson(): self
{
$this->formatJson = true;

return $this;
}

public function embedData(string $data): EmbeddingsResponseDto
{
if (! app()->environment('testing')) {
Expand All @@ -36,6 +45,17 @@ protected function messagesToArray(array $messages): array
})->toArray();
}

public function addJsonFormat(array $payload): array
{
if ($this->formatJson) {
$payload['response_format'] = [
'type' => 'json_object',
];
}

return $payload;
}

/**
* This is to get functions out of the llm
* if none are returned your system
Expand Down
42 changes: 28 additions & 14 deletions Modules/LlmDriver/app/ClaudeClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public function completion(string $prompt): CompletionResponse

Log::info('LlmDriver::Claude::completion');

$results = $this->getClient()->post('/messages', [
$payload = [
'model' => $model,
'max_tokens' => $maxTokens,
'messages' => [
Expand All @@ -88,9 +88,13 @@ public function completion(string $prompt): CompletionResponse
'content' => $prompt,
],
],
]);
];

if (! $results->ok()) {
$payload = $this->addJsonFormat($payload);

$results = $this->getClient()->post('/messages', $payload);

if ($results->failed()) {
$error = $results->json()['error']['type'];
$message = $results->json()['error']['message'];
Log::error('Claude API Error Chat', [
Expand All @@ -112,6 +116,12 @@ public function completion(string $prompt): CompletionResponse
]);
}

public function addJsonFormat(array $payload): array
{
//not available for Claude
return $payload;
}

/**
* @return CompletionResponse[]
*
Expand All @@ -133,31 +143,35 @@ public function completionPool(array $prompts, int $temperature = 0): array
$model,
$maxTokens) {
foreach ($prompts as $prompt) {
$payload = [
'model' => $model,
'max_tokens' => $maxTokens,
'messages' => [
[
'role' => 'user',
'content' => $prompt,
],
],
];

$payload = $this->addJsonFormat($payload);

$pool->retry(3, 6000)->withHeaders([
'x-api-key' => $api_token,
'anthropic-beta' => 'tools-2024-04-04',
'anthropic-version' => $this->version,
'content-type' => 'application/json',
])->baseUrl($this->baseUrl)
->timeout(240)
->post('/messages', [
'model' => $model,
'max_tokens' => $maxTokens,
'messages' => [
[
'role' => 'user',
'content' => $prompt,
],
],
]);
->post('/messages', $payload);
}

});

$results = [];

foreach ($responses as $index => $response) {
if ($response->ok()) {
if ($response->successful()) {
foreach ($response->json()['content'] as $content) {
$results[] = CompletionResponse::from([
'content' => $content['text'],
Expand Down
172 changes: 126 additions & 46 deletions Modules/LlmDriver/app/Functions/ReportingTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@

namespace LlmLaraHub\LlmDriver\Functions;

use App\Domains\Messages\RoleEnum;
use App\Domains\Prompts\ReportBuildingFindRequirementsPrompt;
use App\Domains\Prompts\ReportingSummaryPrompt;
use App\Domains\Reporting\ReportTypeEnum;
use App\Models\Document;
use App\Models\Message;
use App\Models\Report;
use App\Models\Section;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\FunctionResponse;
use LlmLaraHub\LlmDriver\ToolsHelper;

class ReportingTool extends FunctionContract
{
use ToolsHelper;

protected string $name = 'reporting_tool';

protected string $description = 'Uses Reference collection to generate a report';
Expand All @@ -31,6 +38,8 @@ public function handle(
//make or update a reports table for this message - chat
//gather all the documents
//and for each document
//GOT TO TRIGGER THE PROMPT TO BE RELATIVE
// like a summary of the goal or something
//build up a list of sections that are requests (since this is a flexible tool that will be part of a prompt
//then save each one with a reference to the document, chunk to the sections table
//then for each section review each related collections solutions to make numerous
Expand All @@ -56,73 +65,144 @@ public function handle(

$this->results = [];

foreach ($documents->chunk(3) as $index => $databaseChunk) {
try {
$prompts = [];
$documents = [];

foreach ($databaseChunk as $document) {
$documents[] = $document;
$content = $document->document_chunks->pluck('content')->toArray();

$content = implode("\n", $content);

/**
* @NOTE
* This assumes a small amount of incoming content to check
* The user my upload a blog post that is 20 paragraphs or more.
*/
$prompt = ReportBuildingFindRequirementsPrompt::prompt(
$content, $message->getContent(), $message->getChatable()->description
);
$this->promptHistory[] = $prompt;
$prompts[] = $prompt;

}
$sectionContent = [];

$results = LlmDriverFacade::driver($message->getDriver())
->completionPool($prompts);

foreach ($results as $result) {
//make the sections per the results coming back.
$content = $result->content;
$content = json_decode($content, true);
foreach ($content as $sectionIndex => $sectionText) {
$title = data_get($sectionText, 'title', 'NOT TITLE GIVEN');
$content = data_get($sectionText, 'content', 'NOT CONTENT GIVEN');

$section = Section::updateOrCreate([
'document_id' => $document->id,
'report_id' => $report->id,
'sort_order' => $sectionIndex,
], [
'subject' => $title,
'content' => $content,
]);
}
Log::info('[LaraChain] - Reporting Tool', [
'documents_count' => count($documents),
]);

$this->results[] = $section->content;
foreach ($documents as $index => $document) {
try {

$groupedChunks = $document->document_chunks
->sortBy('sort_order')
->groupBy('sort_order');

$pagesGrouped = $groupedChunks->map(function ($chunks, $pageNumber) {
return [
"page_$pageNumber" => $chunks->toArray(),
];
})->collapse();

foreach (collect($pagesGrouped)
->chunk(3) as $pageIndex => $pagesChunk) {
$prompts = [];
foreach ($pagesChunk as $index => $page) {
$pageContent = collect($page)->pluck('content')->toArray();
$pageContent = implode("\n", $pageContent);
put_fixture('solutions_prompts_'.$pageIndex.'.txt', $pageContent, false);
$prompt = ReportBuildingFindRequirementsPrompt::prompt(
$pageContent, $message->getContent(), $message->getChatable()->description
);
$prompts[] = $prompt;
put_fixture('solutions_prompts_'.$index.'.txt', $prompts, false);
}
$this->poolPrompt($prompts, $report, $document);
}

} catch (\Exception $e) {
Log::error('Error running Reporting Tool Checker', [
'error' => $e->getMessage(),
'index' => $index,
'line' => $e->getLine(),
]);
}
}

notify_ui($message->getChat(), 'Wow that was a lot of document!');
notify_ui($message->getChat(), 'Wow that was a lot of document! Now to finalize the output');

$response = $this->summarizeReport($report);

$assistantMessage = $message->getChat()->addInput(
message: $response->content,
role: RoleEnum::Assistant,
systemPrompt: $message->getChat()->getChatable()->systemPrompt(),
show_in_thread: true,
meta_data: $message->meta_data,
tools: $message->tools
);

$this->savePromptHistory($assistantMessage,
implode("\n", $this->promptHistory));

$report->message_id = $assistantMessage->id;
$report->save();

//as a final output
//get deadlines
//get contacts

return FunctionResponse::from([
'content' => implode('\n', $this->results),
'content' => $response->content,
'prompt' => implode('\n', $this->promptHistory),
'requires_followup' => false,
'documentChunks' => collect([]),
'save_to_message' => false,
]);
}

protected function poolPrompt(array $prompts, Report $report, Document $document): void
{
$results = LlmDriverFacade::driver($report->getDriver())
->completionPool($prompts);
foreach ($results as $resultIndex => $result) {
//make the sections per the results coming back.
$content = $result->content;
put_fixture('solutions_content_'.$resultIndex.'.txt', $content, false);
$this->makeSectionFromContent($content, $document, $report);
}
}

protected function summarizeReport(Report $report): CompletionResponse
{
$sectionContent = $report->refresh()->sections->pluck('content')->toArray();
$sectionContent = implode("\n", $sectionContent);

$prompt = ReportingSummaryPrompt::prompt($sectionContent);

$this->promptHistory = [$prompt];

/** @var CompletionResponse $response */
$response = LlmDriverFacade::driver(
$report->getChatable()->getDriver()
)->completion($prompt);

return $response;
}

protected function makeSectionFromContent(
string $content,
Document $document,
Report $report): void
{
try {
$content = str($content)
->remove('```json')
->remove('```')
->toString();
$contentDecoded = json_decode($content, true);
foreach ($contentDecoded as $sectionIndex => $sectionText) {
$title = data_get($sectionText, 'title', 'NOT TITLE GIVEN');
$contentBody = data_get($sectionText, 'content', 'NOT CONTENT GIVEN');
Section::updateOrCreate([
'document_id' => $document->id,
'report_id' => $report->id,
'sort_order' => $report->refresh()->sections->count() + 1,
], [
'subject' => $title,
'content' => $contentBody,
]);
}
} catch (\Exception $e) {
put_fixture('solutions_content_error.txt', $content, false);
Log::error('Error parsing JSON', [
'error' => $e->getMessage(),
'content' => $content,
'line' => $e->getLine(),
]);
}
}

/**
* @return PropertyDto[]
*/
Expand Down
8 changes: 1 addition & 7 deletions Modules/LlmDriver/app/Orchestrate.php
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,7 @@ public function handle(
tools: $message->tools);

if ($response->prompt) {
PromptHistory::create([
'prompt' => $response->prompt,
'chat_id' => $chat->getChat()->id,
'message_id' => $assistantMessage?->id,
/** @phpstan-ignore-next-line */
'collection_id' => $chat->getChatable()?->id,
]);
$this->savePromptHistory($message, $response->prompt);
}

if (! empty($response->documentChunks) && $assistantMessage?->id) {
Expand Down
12 changes: 12 additions & 0 deletions Modules/LlmDriver/app/ToolsHelper.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use App\Domains\Chat\ToolsDto;
use App\Models\Message;
use App\Models\PromptHistory;
use LlmLaraHub\LlmDriver\Functions\FunctionCallDto;

trait ToolsHelper
Expand All @@ -19,4 +20,15 @@ protected function addToolsToMessage(Message $message, FunctionCallDto $function

return $message->refresh();
}

protected function savePromptHistory(Message $message, string $prompt): void
{
PromptHistory::create([
'prompt' => $prompt,
'chat_id' => $message->getChat()->id,
'message_id' => $message?->id,
/** @phpstan-ignore-next-line */
'collection_id' => $message->getChatable()?->id,
]);
}
}
2 changes: 1 addition & 1 deletion Modules/LlmDriver/tests/Feature/ClaudeClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public function test_completion_pool(): void

Http::preventStrayRequests();

$results = $client->completionPool([
$results = $client->setFormatJson()->completionPool([
'test1',
'test2',
'test3',
Expand Down
1 change: 0 additions & 1 deletion app/Domains/Messages/SearchAndSummarizeChatRepo.php
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ public function search(
$latestMessagesArray = $assistantMessage->getLatestMessages();

Log::info('[LaraChain] Getting the Summary', [
'input' => $contentFlattened,
'driver' => $chat->chatable->getDriver(),
'messages' => count($latestMessagesArray),
]);
Expand Down
Loading

0 comments on commit 8655ee3

Please sign in to comment.