Skip to content

Commit

Permalink
move the search and or summaize to the main chat area so Ollama can d…
Browse files Browse the repository at this point in the history
…o that now
  • Loading branch information
alnutile committed May 16, 2024
1 parent 161a616 commit 31915be
Show file tree
Hide file tree
Showing 15 changed files with 275 additions and 78 deletions.
52 changes: 35 additions & 17 deletions Modules/LlmDriver/app/NonFunctionSearchOrSummarize.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,21 @@
use App\Models\DocumentChunk;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\Responses\NonFunctionResponseDto;

class NonFunctionSearchOrSummarize
{
protected string $results = "";

public function handle(string $input, Collection $collection) : string
public function handle(string $input, HasDrivers $collection): NonFunctionResponseDto
{
$collection = $collection->getChatable();

if (! get_class($collection) === Collection::class) {
throw new \Exception('Can only do Collection class right now');
}

Log::info("[LaraChain] - Using the Non Function Search and Summarize Prompt", [
Log::info('[LaraChain] - Using the Non Function Search and Summarize Prompt', [
'collection' => $collection->id,
'input' => $input
'input' => $input,
]);

$prompt = SearchOrSummarize::prompt($input);
Expand All @@ -35,7 +39,7 @@ public function handle(string $input, Collection $collection) : string

if (str($response->content)->contains('search')) {
Log::info('[LaraChain] - LLM Thinks it is Search', [
'response' => $response->content]
'response' => $response->content]
);

$embedding = LlmDriverFacade::driver(
Expand Down Expand Up @@ -78,10 +82,16 @@ public function handle(string $input, Collection $collection) : string
$collection->getDriver()
)->completion($contentFlattened);

$this->results = $response->content;
return NonFunctionResponseDto::from(
[
'response' => $response->content,
'documentChunks' => $documentChunkResults,
'prompt' => $contentFlattened,
]
);
} elseif (str($response->content)->contains('summarize')) {
Log::info('[LaraChain] - LLM Thinks it is summarize', [
'response' => $response->content]
'response' => $response->content]
);

$content = [];
Expand All @@ -94,8 +104,8 @@ public function handle(string $input, Collection $collection) : string
$contentFlattened = implode(' ', $content);

Log::info('[LaraChain] - Documents Flattened', [
'collection' => $collection->id,
'content' => $content]
'collection' => $collection->id,
'content' => $content]
);

$prompt = SummarizeDocumentPrompt::prompt($contentFlattened);
Expand All @@ -104,11 +114,16 @@ public function handle(string $input, Collection $collection) : string
$collection->getDriver()
)->completion($prompt);


$this->results = $response->content;
return NonFunctionResponseDto::from(
[
'response' => $response->content,
'documentChunks' => collect(),
'prompt' => $prompt,
]
);
} else {
Log::info('[LaraChain] - LLM is not sure :(', [
'response' => $response->content]
'response' => $response->content]
);

$embedding = LlmDriverFacade::driver(
Expand Down Expand Up @@ -146,11 +161,14 @@ public function handle(string $input, Collection $collection) : string
$collection->getDriver()
)->completion($contentFlattened);


$this->results = $response->content;
return NonFunctionResponseDto::from(
[
'response' => $response->content,
'documentChunks' => collect(),
'prompt' => $contentFlattened,
]
);

}

return $this->results;
}
}
17 changes: 17 additions & 0 deletions Modules/LlmDriver/app/Responses/NonFunctionResponseDto.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Illuminate\Support\Collection;
use Spatie\LaravelData\Data;

class NonFunctionResponseDto extends Data
{
public function __construct(
public Collection $documentChunks,
public string $response = '',
public string $prompt = '',
) {

}
}
17 changes: 8 additions & 9 deletions Modules/LlmDriver/tests/Feature/ClaudeClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ public function test_chat(): void
$this->assertInstanceOf(CompletionResponse::class, $results);

Http::assertSent(function ($request) {
$message1 = $request->data()['messages'][0]['role'];
$message2 = $request->data()['messages'][1]['role'];
$messageAssistant = $request->data()['messages'][0]['role'];
$messageUser = $request->data()['messages'][1]['role'];

return $message2 === 'assistant' &&
$message1 === 'user';
return $messageAssistant === 'assistant' &&
$messageUser === 'user';
});

}
Expand Down Expand Up @@ -109,12 +109,11 @@ public function test_chat_with_multiple_assistant_messages(): void
$this->assertInstanceOf(CompletionResponse::class, $results);

Http::assertSent(function ($request) {
$message0 = $request->data()['messages'][0]['role'];
$message1 = $request->data()['messages'][1]['role'];
$message2 = $request->data()['messages'][2]['role'];
$messageAssistant = $request->data()['messages'][1]['role'];
$messageUser = $request->data()['messages'][2]['role'];

return $message0 === 'assistant' &&
$message1 === 'user' && $message2 === 'assistant';
return $messageAssistant === 'assistant' &&
$messageUser === 'user';
});

}
Expand Down
2 changes: 1 addition & 1 deletion app/Domains/Prompts/SearchOrSummarize.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class SearchOrSummarize
public static function prompt(string $originalPrompt): string
{

Log::info('[LaraChain] - Search or SearchAndSummarize');
Log::info('[LaraChain] - SearchOrSummarize Prompt');

return <<<PROMPT
### Task, Action, Goal (T.A.G)
Expand Down
12 changes: 4 additions & 8 deletions app/Http/Controllers/WebPageOutputController.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,18 @@
use App\Domains\Collections\CollectionStatusEnum;
use App\Domains\Outputs\OutputTypeEnum;
use App\Domains\Prompts\AnonymousChat;
use App\Domains\Prompts\DefaultPrompt;
use App\Domains\Prompts\SearchOrSummarize;
use App\Domains\Prompts\SummarizeDocumentPrompt;
use App\Domains\Prompts\SummarizeForPage;
use App\Domains\Prompts\SummarizePrompt;
use App\Http\Resources\CollectionResource;
use App\Http\Resources\PublicOutputResource;
use App\Models\Collection;
use App\Models\DocumentChunk;
use App\Models\Output;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use Facades\LlmLaraHub\LlmDriver\NonFunctionSearchOrSummarize;
use Illuminate\Support\Arr;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\Helpers\TrimText;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use Facades\LlmLaraHub\LlmDriver\NonFunctionSearchOrSummarize;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\NonFunctionResponseDto;

class WebPageOutputController extends Controller
{
Expand Down Expand Up @@ -86,9 +81,10 @@ public function chat(Output $output)
'message' => $validated['input']]
);

/** @var NonFunctionResponseDto $results */
$results = NonFunctionSearchOrSummarize::handle($input, $output->collection);

$this->setChatMessages($results, 'assistant');
$this->setChatMessages($results->response, 'assistant');

return back();
}
Expand Down
36 changes: 36 additions & 0 deletions app/Jobs/DocumentReferenceJob.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<?php

namespace App\Jobs;

use App\Models\Message;
use Illuminate\Bus\Queueable;
use Illuminate\Contracts\Queue\ShouldQueue;
use Illuminate\Foundation\Bus\Dispatchable;
use Illuminate\Queue\InteractsWithQueue;
use Illuminate\Queue\SerializesModels;
use Illuminate\Support\Collection;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;

class DocumentReferenceJob implements ShouldQueue
{
use CreateReferencesTrait;
use Dispatchable, InteractsWithQueue, Queueable, SerializesModels;

/**
* Create a new job instance.
*/
public function __construct(
public Message $message,
public Collection $documentChunks
) {
//
}

/**
* Execute the job.
*/
public function handle(): void
{
$this->saveDocumentReference($this->message, $this->documentChunks);
}
}
53 changes: 49 additions & 4 deletions app/Jobs/SimpleSearchAndSummarizeOrchestrateJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,29 @@

namespace App\Jobs;

use App\Models\Chat;
use Facades\LlmLaraHub\LlmDriver\SimpleSearchAndSummarizeOrchestrate;
use App\Domains\Messages\RoleEnum;
use App\Models\Collection;
use App\Models\PromptHistory;
use Facades\LlmLaraHub\LlmDriver\NonFunctionSearchOrSummarize;
use Illuminate\Bus\Queueable;
use Illuminate\Contracts\Queue\ShouldQueue;
use Illuminate\Foundation\Bus\Dispatchable;
use Illuminate\Queue\InteractsWithQueue;
use Illuminate\Queue\SerializesModels;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\HasDrivers;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;
use LlmLaraHub\LlmDriver\Responses\NonFunctionResponseDto;

class SimpleSearchAndSummarizeOrchestrateJob implements ShouldQueue
{
use CreateReferencesTrait;
use Dispatchable, InteractsWithQueue, Queueable, SerializesModels;

/**
* Create a new job instance.
*/
public function __construct(public string $input, public Chat $chat)
public function __construct(public string $input, public HasDrivers $chat)
{
//
}
Expand All @@ -27,6 +34,44 @@ public function __construct(public string $input, public Chat $chat)
*/
public function handle(): void
{
SimpleSearchAndSummarizeOrchestrate::handle($this->input, $this->chat);
Log::info('[LaraChain] Skipping over functions doing search and summarize');

notify_ui(
$this->chat->getChatable(),
'Searching data now to summarize content'
);

$collection = $this->chat->getChatable();

if (get_class($collection) === Collection::class) {
/** @var NonFunctionResponseDto $results */
$results = NonFunctionSearchOrSummarize::handle($this->input, $collection);

$message = $this->chat->getChat()->addInput(
message: $results->response,
role: RoleEnum::Assistant,
show_in_thread: true
);

if ($results->prompt) {
PromptHistory::create([
'prompt' => $results->prompt,
'chat_id' => $this->chat->getChat()->id,
'message_id' => $message->id,
/** @phpstan-ignore-next-line */
'collection_id' => $this->chat->getChatable()?->id,
]);
}

if ($results->documentChunks->isNotEmpty()) {
$this->saveDocumentReference(
$message,
$results->documentChunks
);
}
} else {
Log::info('Can only handle Collection model right now');
}

}
}
11 changes: 9 additions & 2 deletions tests/Feature/Http/Controllers/ChatControllerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
use App\Models\Message;
use App\Models\User;
use Facades\App\Domains\Agents\VerifyResponseAgent;
use Facades\LlmLaraHub\LlmDriver\NonFunctionSearchOrSummarize;
use Facades\LlmLaraHub\LlmDriver\Orchestrate;
use Facades\LlmLaraHub\LlmDriver\SimpleSearchAndSummarizeOrchestrate;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\NonFunctionResponseDto;
use Tests\TestCase;

class ChatControllerTest extends TestCase
Expand Down Expand Up @@ -133,7 +134,13 @@ public function test_no_functions()
]);

LlmDriverFacade::shouldReceive('driver->hasFunctions')->once()->andReturn(false);
SimpleSearchAndSummarizeOrchestrate::shouldReceive('handle')->once()->andReturn('Yo');

NonFunctionSearchOrSummarize::shouldReceive('handle')->once()->andReturn(
NonFunctionResponseDto::from([
'response' => 'Foobar',
'documentChunks' => collect(),
'prompt' => 'Foobar',
]));

$this->actingAs($user)->post(route('chats.messages.create', [
'chat' => $chat->id,
Expand Down
20 changes: 15 additions & 5 deletions tests/Feature/Http/Controllers/WebPageOutputControllerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
use App\Models\Output;
use App\Models\User;
use Facades\LlmLaraHub\LlmDriver\DistanceQuery;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use Facades\LlmLaraHub\LlmDriver\NonFunctionSearchOrSummarize;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;
use Pgvector\Laravel\Vector;
use LlmLaraHub\LlmDriver\Responses\NonFunctionResponseDto;
use Tests\TestCase;

class WebPageOutputControllerTest extends TestCase
Expand Down Expand Up @@ -71,7 +70,12 @@ public function test_chat_search()
'public' => true,
]);
NonFunctionSearchOrSummarize::shouldReceive('handle')
->once()->andReturn("Foo");
->once()->andReturn(
NonFunctionResponseDto::from([
'response' => 'Foobar',
'documentChunks' => collect(),
'prompt' => 'Foobar',
]));

$this->post(route(
'collections.outputs.web_page.chat', [
Expand All @@ -85,7 +89,13 @@ public function test_chat_search()
public function test_no_search_no_summary()
{
NonFunctionSearchOrSummarize::shouldReceive('handle')
->once()->andReturn("Foo");
->once()->andReturn(
NonFunctionResponseDto::from([
'response' => 'Foobar',
'documentChunks' => collect(),
'prompt' => 'Foobar',
])
);

$output = Output::factory()->create([
'active' => true,
Expand Down
Loading

0 comments on commit 31915be

Please sign in to comment.