Skip to content

Commit

Permalink
omg this will simplify a lot of things but it sure is making a lot of…
Browse files Browse the repository at this point in the history
… changes
  • Loading branch information
alnutile committed Aug 2, 2024
1 parent 87dd619 commit 076d53b
Show file tree
Hide file tree
Showing 27 changed files with 154 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\FunctionResponse;

class SearchAndSummarize extends FunctionContract
class RetrieveRelated extends FunctionContract
{
use CreateReferencesTrait;

protected string $name = 'search_and_summarize';
protected string $name = 'retrieve_related';

protected string $description = 'Used to embed users prompt, search local database and return summarized results.
DOES NOT SEARCH THE WEB';
DOES NOT SEARCH THE WEB. This is only used for local database search.';

protected string $response = '';

public function handle(
Message $message): FunctionResponse
{
Log::info('[LaraChain] Using Function: SearchAndSummarize');
Log::info('[LaraChain] Using Function: RetrieveRelated');

/**
* @TODO
Expand Down Expand Up @@ -61,7 +61,7 @@ public function handle(

/**
* @NOTE
* Yes this is a lot like the SearchAndSummarizeChatRepo
* Yes this is a lot like the RetrieveRelatedChatRepo
* But just getting a sense of things
*/
foreach ($documentChunkResults as $result) {
Expand Down
4 changes: 2 additions & 2 deletions Modules/LlmDriver/app/LlmDriverClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use LlmLaraHub\LlmDriver\Functions\GatherInfoTool;
use LlmLaraHub\LlmDriver\Functions\GetWebSiteFromUrlTool;
use LlmLaraHub\LlmDriver\Functions\ReportingTool;
use LlmLaraHub\LlmDriver\Functions\SearchAndSummarize;
use LlmLaraHub\LlmDriver\Functions\RetrieveRelated;
use LlmLaraHub\LlmDriver\Functions\SearchTheWeb;
use LlmLaraHub\LlmDriver\Functions\StandardsChecker;
use LlmLaraHub\LlmDriver\Functions\SummarizeCollection;
Expand Down Expand Up @@ -62,7 +62,7 @@ public function getFunctions(): array
{
return [
(new SummarizeCollection())->getFunction(),
(new SearchAndSummarize())->getFunction(),
(new RetrieveRelated())->getFunction(),
(new StandardsChecker())->getFunction(),
(new ReportingTool())->getFunction(),
(new GatherInfoTool())->getFunction(),
Expand Down
4 changes: 2 additions & 2 deletions Modules/LlmDriver/app/LlmServiceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use LlmLaraHub\LlmDriver\Functions\GatherInfoTool;
use LlmLaraHub\LlmDriver\Functions\GetWebSiteFromUrlTool;
use LlmLaraHub\LlmDriver\Functions\ReportingTool;
use LlmLaraHub\LlmDriver\Functions\SearchAndSummarize;
use LlmLaraHub\LlmDriver\Functions\RetrieveRelated;
use LlmLaraHub\LlmDriver\Functions\SearchTheWeb;
use LlmLaraHub\LlmDriver\Functions\StandardsChecker;
use LlmLaraHub\LlmDriver\Functions\SummarizeCollection;
Expand Down Expand Up @@ -66,7 +66,7 @@ public function boot(): void
});

$this->app->bind('search_and_summarize', function () {
return new SearchAndSummarize();
return new RetrieveRelated();
});

$this->app->bind('standards_checker', function () {
Expand Down
3 changes: 1 addition & 2 deletions Modules/LlmDriver/app/OllamaClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,8 @@ public function completion(string $prompt): CompletionResponse
'stream' => false,
]);

$results = $response->json()['response'];

return new CompletionResponse($results);
return OllamaCompletionResponse::from($response->json());
}

protected function getClient()
Expand Down
4 changes: 2 additions & 2 deletions Modules/LlmDriver/app/Orchestrate.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
use App\Models\Filter;
use App\Models\Message;
use App\Models\PromptHistory;
use Facades\App\Domains\Messages\SearchAndSummarizeChatRepo;
use Facades\App\Domains\Messages\RetrieveRelatedChatRepo;
use Illuminate\Support\Facades\Bus;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\Functions\FunctionCallDto;
Expand Down Expand Up @@ -173,7 +173,7 @@ public function handle(
} else {
Log::info('[LaraChain] Orchestration No Functions Default Search And Summarize');

return SearchAndSummarizeChatRepo::search($chat, $message);
return RetrieveRelatedChatRepo::search($chat, $message);
}
}

Expand Down
4 changes: 2 additions & 2 deletions Modules/LlmDriver/app/Responses/ClaudeCompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ public function __construct(
#[WithCastable(ClaudeContentCaster::class)]
public mixed $content,
public string|Optional $stop_reason,
public string|Optional $tool_used,
public string|null $tool_used = "",
/** @var array<ToolDto> */
#[WithCastable(ClaudeToolCaster::class)]
#[MapInputName('content')]
public array|Optional $tool_calls,
public array $tool_calls = [],
#[MapInputName('usage.input_tokens')]
public ?int $input_tokens = null,
#[MapInputName('usage.output_tokens')]
Expand Down
4 changes: 2 additions & 2 deletions Modules/LlmDriver/app/Responses/CompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class CompletionResponse extends \Spatie\LaravelData\Data
public function __construct(
public mixed $content,
public string|Optional $stop_reason,
public string|Optional $tool_used,
public string|null $tool_used = "",
/** @var array<ToolDto> */
public array|Optional $tool_calls,
public array $tool_calls = [],
public ?int $input_tokens = null,
public ?int $output_tokens = null,
public ?string $model = null,
Expand Down
4 changes: 2 additions & 2 deletions Modules/LlmDriver/app/Responses/OllamaCompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ public function __construct(
public mixed $content,
#[MapInputName('done_reason')]
public string|Optional $stop_reason,
public string|Optional $tool_used,
public string|null $tool_used = "",
/** @var array<OllamaToolDto> */
#[MapInputName('message.tool_calls')]
public array|Optional $tool_calls,
public array $tool_calls = [],
#[MapInputName('prompt_eval_count')]
public ?int $input_tokens = null,
#[MapInputName('eval_count')]
Expand Down
6 changes: 3 additions & 3 deletions Modules/LlmDriver/app/SimpleSearchAndSummarizeOrchestrate.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
namespace LlmLaraHub\LlmDriver;

use App\Models\Chat;
use Facades\App\Domains\Messages\SearchAndSummarizeChatRepo;
use Facades\App\Domains\Messages\RetrieveRelatedChatRepo;
use Illuminate\Support\Facades\Log;

class SimpleSearchAndSummarizeOrchestrate
class SimpleRetrieveRelatedOrchestrate
{
protected string $response = '';

Expand All @@ -20,7 +20,7 @@ public function handle(string $message, Chat $chat): ?string
$chat->chatable,
'Searching data now to summarize content'
);
$response = SearchAndSummarizeChatRepo::search($chat, $message);
$response = RetrieveRelatedChatRepo::search($chat, $message);

return $response;
}
Expand Down
19 changes: 8 additions & 11 deletions Modules/LlmDriver/tests/Feature/ClaudeClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -321,25 +321,22 @@ public function test_remap_messages_with_tools_as_history()
$messages[] = MessageInDto::from([
'content' => 'test3',
'role' => RoleEnum::Tool->value,
'meta_data' => MetaDataDto::from([
'tool' => 'test',
'tool_id' => 'test_id',
]),
'tool' => 'test',
'tool_id' => 'test_id',
'meta_data' => MetaDataDto::from([]),
]);

$results = (new ClaudeClient)->remapMessages($messages);

$this->assertCount(3, $results);
$this->assertCount(5, $results);

$this->assertEquals('user', $results[0]['role']);
$this->assertEquals('assistant', $results[1]['role']);
$this->assertEquals('<thinking>test</thinking>', $results[1]['content'][0]['text']);
$this->assertEquals('<thinking>test3</thinking>', $results[3]['content'][0]['text']);

$this->assertEquals('user', $results[2]['role']);
$this->assertEquals('tool_result', $results[2]['content'][0]['type']);
$this->assertEquals('test', $results[2]['content'][0]['content']);
$this->assertEquals('test_id', $results[2]['content'][0]['tool_use_id']);

$this->assertArrayNotHasKey('tool_id', $results[2]);
$this->assertEquals('tool_use', $results[3]['content'][1]['type']);
$this->assertEquals('test', $results[3]['content'][1]['name']);
$this->assertEquals('test_id', $results[3]['content'][1]['id']);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;
use LlmLaraHub\LlmDriver\ToolsHelper;

class SearchAndSummarizeChatRepo
class RetrieveRelatedChatRepo
{
use CreateReferencesTrait;
use ToolsHelper;
Expand Down
2 changes: 1 addition & 1 deletion app/Jobs/EmailReplyOutputJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public function handle(): void

/**
* @NOTE
* Yes this is a lot like the SearchAndSummarizeChatRepo
* Yes this is a lot like the RetrieveRelatedChatRepo
* But just getting a sense of things
*/
foreach ($documentChunkResults as $result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
* This is used by LLMs that might not have functions
* but still want to do a search and summarize
*/
class SimpleSearchAndSummarizeOrchestrateJob implements ShouldQueue
class SimpleRetrieveRelatedOrchestrateJob implements ShouldQueue
{
use Batchable;
use CreateReferencesTrait;
Expand All @@ -42,7 +42,7 @@ public function __construct(
*/
public function handle(): void
{
Log::info('[LaraChain] Skipping over functions doing SimpleSearchAndSummarizeOrchestrateJob');
Log::info('[LaraChain] Skipping over functions doing SimpleRetrieveRelatedOrchestrateJob');

notify_ui(
$this->message->getChatable(),
Expand Down
3 changes: 2 additions & 1 deletion app/Models/Message.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use App\Events\ChatUiUpdateEvent;
use App\Events\MessageCreatedEvent;
use App\Jobs\OrchestrateJob;
use App\Jobs\SimpleRetrieveRelatedOrchestrateJob;
use App\Jobs\SimpleSearchAndSummarizeOrchestrateJob;
use Facades\App\Domains\Tokenizer\Templatizer;
use Illuminate\Bus\Batch;
Expand Down Expand Up @@ -259,7 +260,7 @@ public function run(): void
} else {
Log::info('[LaraChain] Simple Search and Summarize added to queue');
$this->batchJob([
new SimpleSearchAndSummarizeOrchestrateJob($message),
new SimpleRetrieveRelatedOrchestrateJob($message),
], $chat, 'simple_search_and_summarize');
}

Expand Down
24 changes: 0 additions & 24 deletions tests/Feature/Http/Controllers/ChatControllerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -178,31 +178,7 @@ public function test_kick_off_chat_makes_system()

}

public function test_no_functions()
{
Bus::fake();
$user = User::factory()->create();
$collection = Collection::factory()->create();
$chat = Chat::factory()->create([
'chatable_id' => $collection->id,
'chatable_type' => Collection::class,
'user_id' => $user->id,
]);

LlmDriverFacade::shouldReceive('driver->hasFunctions')->once()->andReturn(false);
LlmDriverFacade::shouldReceive('getFunctionsForUi')->andReturn([]);

$this->actingAs($user)->post(route('chats.messages.create', [
'chat' => $chat->id,
]),
[
'system_prompt' => 'Foo',
'input' => 'user input',
])->assertOk();

Bus::assertBatchCount(1);

}

public function test_standard_checker()
{
Expand Down
1 change: 1 addition & 0 deletions tests/Feature/Models/DocumentTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public function test_factory(): void
public function test_parent()
{
$modelParent = \App\Models\Document::factory()->create();

$model = \App\Models\Document::factory()->create([
'parent_id' => $modelParent->id,
]);
Expand Down
10 changes: 5 additions & 5 deletions tests/Feature/OrchestrateTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
use App\Models\Collection;
use App\Models\Message;
use App\Models\User;
use Facades\App\Domains\Messages\SearchAndSummarizeChatRepo;
use Facades\App\Domains\Messages\RetrieveRelatedChatRepo;
use Illuminate\Support\Facades\Event;
use LlmLaraHub\LlmDriver\Functions\SearchAndSummarize;
use LlmLaraHub\LlmDriver\Functions\RetrieveRelated;
use LlmLaraHub\LlmDriver\Functions\StandardsChecker;
use LlmLaraHub\LlmDriver\Functions\SummarizeCollection;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
Expand Down Expand Up @@ -39,7 +39,7 @@ public function test_gets_summarize_function(): void

LlmDriverFacade::shouldReceive('driver->chat')->never();

SearchAndSummarizeChatRepo::shouldReceive('search')->never();
RetrieveRelatedChatRepo::shouldReceive('search')->never();

$this->instance(
'summarize_collection',
Expand Down Expand Up @@ -149,11 +149,11 @@ public function test_makes_history_no_message(): void

LlmDriverFacade::shouldReceive('driver->chat')->never();

SearchAndSummarizeChatRepo::shouldReceive('search')->never();
RetrieveRelatedChatRepo::shouldReceive('search')->never();

$this->instance(
'search_and_summarize',
Mockery::mock(SearchAndSummarize::class, function ($mock) {
Mockery::mock(RetrieveRelated::class, function ($mock) {
$mock->shouldReceive('handle')
->once()
->andReturn(
Expand Down
4 changes: 2 additions & 2 deletions tests/Feature/ReportingToolTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class ReportingToolTest extends TestCase
*/
public function test_can_generate_function_as_array(): void
{
$searchAndSummarize = new \LlmLaraHub\LlmDriver\Functions\ReportingTool();
$RetrieveRelated = new \LlmLaraHub\LlmDriver\Functions\ReportingTool();

$function = $searchAndSummarize->getFunction();
$function = $RetrieveRelated->getFunction();

$parameters = $function->parameters;

Expand Down
6 changes: 3 additions & 3 deletions tests/Feature/SearchAndSummarizeChatRepoTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use App\Domains\Agents\VerifyPromptOutputDto;
use App\Domains\Chat\MetaDataDto;
use App\Domains\Messages\RoleEnum;
use App\Domains\Messages\SearchAndSummarizeChatRepo;
use App\Domains\Messages\RetrieveRelatedChatRepo;
use App\Models\Chat;
use App\Models\Collection;
use App\Models\Document;
Expand All @@ -16,7 +16,7 @@
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use Tests\TestCase;

class SearchAndSummarizeChatRepoTest extends TestCase
class RetrieveRelatedChatRepoTest extends TestCase
{
/**
* A basic feature test example.
Expand Down Expand Up @@ -84,7 +84,7 @@ public function test_can_search(): void
]),
]);

$results = (new SearchAndSummarizeChatRepo())->search($chat, $message);
$results = (new RetrieveRelatedChatRepo())->search($chat, $message);

$this->assertNotNull($results);
$this->assertDatabaseCount('message_document_references', 3);
Expand Down
10 changes: 5 additions & 5 deletions tests/Feature/SearchAndSummarizeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@
use LlmLaraHub\LlmDriver\DistanceQuery\DistanceQueryFacade;
use LlmLaraHub\LlmDriver\Functions\ParametersDto;
use LlmLaraHub\LlmDriver\Functions\PropertyDto;
use LlmLaraHub\LlmDriver\Functions\SearchAndSummarize;
use LlmLaraHub\LlmDriver\Functions\RetrieveRelated;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use Tests\TestCase;

class SearchAndSummarizeTest extends TestCase
class RetrieveRelatedTest extends TestCase
{
/**
* A basic feature test example.
*/
public function test_can_generate_function_as_array(): void
{
$searchAndSummarize = new \LlmLaraHub\LlmDriver\Functions\SearchAndSummarize();
$RetrieveRelated = new \LlmLaraHub\LlmDriver\Functions\RetrieveRelated();

$function = $searchAndSummarize->getFunction();
$function = $RetrieveRelated->getFunction();

$parameters = $function->parameters;

Expand Down Expand Up @@ -117,7 +117,7 @@ public function test_gets_user_input()
));

$message = Message::factory()->user()->create(['chat_id' => $chat->id]);
$results = (new SearchAndSummarize())->handle($message);
$results = (new RetrieveRelated())->handle($message);

$this->assertNotNull($results);
$this->assertDatabaseCount('message_document_references', 3);
Expand Down
Loading

0 comments on commit 076d53b

Please sign in to comment.