Skip to content

Commit

Permalink
This now adds the gathering info prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed Jul 25, 2024
1 parent b406feb commit de859e2
Show file tree
Hide file tree
Showing 22 changed files with 595 additions and 13 deletions.
140 changes: 140 additions & 0 deletions Modules/LlmDriver/app/Functions/GatherInfoTool.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
<?php

namespace LlmLaraHub\LlmDriver\Functions;

use App\Domains\Prompts\PromptMerge;
use App\Domains\Reporting\ReportTypeEnum;
use App\Domains\Reporting\StatusEnum;
use App\Jobs\GatherInfoFinalPromptJob;
use App\Jobs\GatherInfoReportSectionsJob;
use App\Models\Collection;
use App\Models\Message;
use App\Models\Report;
use Illuminate\Bus\Batch;
use Illuminate\Support\Facades\Bus;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\Responses\FunctionResponse;
use LlmLaraHub\LlmDriver\ToolsHelper;

class GatherInfoTool extends FunctionContract
{
use ToolsHelper;

protected string $name = 'gather_info_tool';

protected string $description = 'This will look at all documents using your prompt then return the results after once more using your prompt';

protected string $response = '';

protected array $results = [];

protected array $sectionJobs = [];

public function handle(
Message $message): FunctionResponse
{
Log::info('[LaraChain] GatherInfoTool Function called');

$report = Report::firstOrCreate([
'chat_id' => $message->getChat()->id,
'message_id' => $message->id,
'reference_collection_id' => $message->getReferenceCollection()?->id,
'user_id' => $message->getChat()->user_id,
], [
'type' => ReportTypeEnum::GatherInfo,
'user_message_id' => $message->id,
'status_sections_generation' => StatusEnum::Pending,
'status_entries_generation' => StatusEnum::Pending,
]);

$collection = $message->getChatable();

notify_ui($message->getChat(), 'Going through all the documents to check requirements');

$this->results = [];

Log::info('[LaraChain] - GatherInfo Tool');

$this->buildUpSections($collection, $report, $message);

Bus::batch($this->sectionJobs)
->name(sprintf('GatherInfo Tool Sections Report id: %d Chat id %d', $report->id, $message->getChat()->id))
->allowFailures()
->finally(function (Batch $batch) use ($report) {
$report->update(['status_sections_generation' => StatusEnum::Complete]);
Bus::batch([
new GatherInfoFinalPromptJob($report),
])->name(sprintf('Reporting Tool Summarize Report Id %s', $report->id))
->allowFailures()
->dispatch();

})
->dispatch();

$report->update([
'status_sections_generation' => StatusEnum::Running,
]);

notify_ui($report->getChat(), 'Running');

return FunctionResponse::from([
'content' => 'Gathering info and then running prompt',
'prompt' => $message->getPrompt(),
'requires_followup' => false,
'documentChunks' => collect([]),
'save_to_message' => false,
]);
}

protected function buildUpSections(Collection $collection, Report $report, Message $message): void
{
$collection->documents()->chunk(3, callback: function ($documentChunks) use ($message, $report) {
try {
$prompts = [];
foreach ($documentChunks as $document) {
$prompt = PromptMerge::merge(
['[CONTEXT]'],
[$document->original_content],
$message->getPrompt()
);

$prompts[] = $prompt;
}

if ($document?->id) {
$this->sectionJobs[] =
new GatherInfoReportSectionsJob(
prompts: $prompts,
report: $report,
document: $document);
}

} catch (\Exception $e) {
Log::error('Error running Reporting Tool Checker', [
'error' => $e->getMessage(),
'line' => $e->getLine(),
]);
}
});
}

/**
* @return PropertyDto[]
*/
protected function getProperties(): array
{
return [
new PropertyDto(
name: 'prompt',
description: 'Using your prompt we will look at every document, run your prompt against each one and then against the final output',
type: 'string',
required: true,
),
];
}

public function runAsBatch(): bool
{
return true;
}
}
69 changes: 69 additions & 0 deletions Modules/LlmDriver/app/Functions/GatherInfoToolMakeSections.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
<?php

namespace LlmLaraHub\LlmDriver\Functions;

use App\Models\Document;
use App\Models\Report;
use App\Models\Section;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\LlmDriverFacade;

class GatherInfoToolMakeSections
{
public function handle(
array $prompts,
Report $report,
Document $document
) {

$this->poolPrompt($prompts, $report, $document);

}

protected function poolPrompt(array $prompts, Report $report, Document $document): void
{
notify_ui($report->getChat(), 'Running Prompt against gathered info');
notify_ui_report($report, 'Running Prompt against gathered info');

Log::info('LlmDriver::GatherInfoToolMakeSections::poolPrompt', [
'driver' => $report->getDriver(),
'prompts' => $prompts,
]);

$results = LlmDriverFacade::driver($report->getDriver())
->completionPool($prompts);

foreach ($results as $resultIndex => $result) {
$content = $result->content;
$this->makeSectionFromContent($content, $document, $report);
}

notify_ui($report->getChat(), 'Done gathering info');
notify_ui_report($report, 'Done gathering info');
}

protected function makeSectionFromContent(
string $content,
Document $document,
Report $report): void
{
try {

Section::updateOrCreate([
'document_id' => $document->id,
'report_id' => $report->id,
'sort_order' => $report->refresh()->sections->count() + 1,
], [
'subject' => str($content)->limit(128)->toString(),
'content' => $content,
]);

} catch (\Exception $e) {
Log::error('Error creating section', [
'error' => $e->getMessage(),
'content' => $content,
'line' => $e->getLine(),
]);
}
}
}
1 change: 1 addition & 0 deletions Modules/LlmDriver/app/Functions/ReportingTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public function handle(
'user_id' => $message->getChat()->user_id,
], [
'type' => ReportTypeEnum::RFP,
'user_message_id' => $message->id,
'status_sections_generation' => StatusEnum::Pending,
'status_entries_generation' => StatusEnum::Pending,
]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public function handle(Report $report)
$this->savePromptHistory($assistantMessage,
implode("\n", $this->promptHistory));

$report->user_message_id = $report->message_id;
$report->message_id = $assistantMessage->id;
$report->save();

Expand Down
2 changes: 2 additions & 0 deletions Modules/LlmDriver/app/LlmDriverClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace LlmLaraHub\LlmDriver;

use LlmLaraHub\LlmDriver\Functions\GatherInfoTool;
use LlmLaraHub\LlmDriver\Functions\ReportingTool;
use LlmLaraHub\LlmDriver\Functions\SearchAndSummarize;
use LlmLaraHub\LlmDriver\Functions\StandardsChecker;
Expand Down Expand Up @@ -61,6 +62,7 @@ public function getFunctions(): array
(new SearchAndSummarize())->getFunction(),
(new StandardsChecker())->getFunction(),
(new ReportingTool())->getFunction(),
(new GatherInfoTool())->getFunction(),
];
}

Expand Down
5 changes: 5 additions & 0 deletions Modules/LlmDriver/app/LlmServiceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use Illuminate\Support\Facades\Log;
use Illuminate\Support\ServiceProvider;
use LlmLaraHub\LlmDriver\DistanceQuery\DistanceQueryClient;
use LlmLaraHub\LlmDriver\Functions\GatherInfoTool;
use LlmLaraHub\LlmDriver\Functions\ReportingTool;
use LlmLaraHub\LlmDriver\Functions\SearchAndSummarize;
use LlmLaraHub\LlmDriver\Functions\StandardsChecker;
Expand Down Expand Up @@ -73,6 +74,10 @@ public function boot(): void
return new ReportingTool();
});

$this->app->bind('gather_info_tool', function () {
return new GatherInfoTool();
});

}

/**
Expand Down
1 change: 1 addition & 0 deletions app/Domains/Reporting/ReportTypeEnum.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
enum ReportTypeEnum: string
{
case RFP = 'rfp';
case GatherInfo = 'gather_info';
}
1 change: 1 addition & 0 deletions app/Http/Resources/ReportResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public function toArray(Request $request): array
'user_id' => $this->user_id,
'chat_id' => $this->chat_id,
'type' => $this->type,
'type_formatted' => str($this->type->name)->headline()->toString(),
'reference_collection' => new CollectionResource($this->reference_collection),
'sections' => SectionResource::collection($this->sections),
'status_sections_generation' => $this->status_sections_generation?->value,
Expand Down
14 changes: 7 additions & 7 deletions app/Http/Resources/UserResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class UserResource extends JsonResource
*/
public function toArray(Request $request): array
{
return [
'id' => $this->id,
'name' => $this->name,
'email' => $this->email,
'email_verified_at' => $this->email_verified_at,
'current_team_id' => $this->current_team_id,
];
return [
'id' => $this->id,
'name' => $this->name,
'email' => $this->email,
'email_verified_at' => $this->email_verified_at,
'current_team_id' => $this->current_team_id,
];
}
}
95 changes: 95 additions & 0 deletions app/Jobs/GatherInfoFinalPromptJob.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
<?php

namespace App\Jobs;

use App\Domains\Messages\RoleEnum;
use App\Domains\Reporting\StatusEnum;
use Illuminate\Bus\Batchable;
use Illuminate\Bus\Queueable;
use Illuminate\Contracts\Queue\ShouldQueue;
use Illuminate\Foundation\Bus\Dispatchable;
use Illuminate\Queue\InteractsWithQueue;
use Illuminate\Queue\SerializesModels;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\ToolsHelper;

class GatherInfoFinalPromptJob implements ShouldQueue
{
use Batchable, ToolsHelper;
use Dispatchable, InteractsWithQueue, Queueable, SerializesModels;

/**
* Create a new job instance.
*/
public function __construct(
public \App\Models\Report $report
) {
//
}

/**
* Execute the job.
*/
public function handle(): void
{
//@NOTE
// if completion does not work I can make this a
// chat with all the results
//ok we have all the sections lets make a final prompt output
//
$messages = [];

$history = [];

$messages[] = MessageInDto::from([
'content' => 'You are an assistant helping with this gathered info.',
'role' => 'system',
]);

foreach ($this->report->sections as $section) {
$messages[] = MessageInDto::from([
'content' => $section->content,
'role' => 'user',
]);

$history[] = $section->content;

$messages[] = MessageInDto::from([
'content' => 'Using the surrounding context to continue this response thread',
'role' => 'assistant',
]);
}

$messages[] = MessageInDto::from([
'content' => sprintf('Using the context of this chat can you '.
$this->report->message->getPrompt()),
'role' => 'user',
]);

$response = LlmDriverFacade::driver($this->report->getDriver())
->chat($messages);

$assistantMessage = $this->report->getChat()->addInput(
message: $response->content,
role: RoleEnum::Assistant,
systemPrompt: 'You are an assistant helping with this gathered info.',
show_in_thread: true,
meta_data: $this->report->message->meta_data,
tools: $this->report->message->tools
);

$this->report->user_message_id = $this->report->message_id;
$this->report->message_id = $assistantMessage->id;
$this->report->status_entries_generation = StatusEnum::Complete;
$this->report->save();

$this->savePromptHistory($assistantMessage,
implode("\n", $history));

notify_ui($this->report->getChat(), 'Building Solutions list');
notify_ui_report($this->report, 'Building Solutions list');
notify_ui_complete($this->report->getChat());

}
}
Loading

0 comments on commit de859e2

Please sign in to comment.