Skip to content

Commit

Permalink
Tons of QA to go
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed Jul 5, 2024
1 parent 0a54b95 commit 85d8823
Show file tree
Hide file tree
Showing 14 changed files with 188 additions and 64 deletions.
32 changes: 20 additions & 12 deletions Modules/LlmDriver/app/ClaudeClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@ public function chat(array $messages): CompletionResponse
/**
* I need to iterate over each item
* then if there are two rows with role assistant I need to insert
* in betwee a user row with some copy to make it work like "And the user search results had"
* in between a user row with some copy to make it work like "And the user search results had"
* using the Laravel Collection library
*/
$messages = $this->remapMessages($messages);

put_fixture('after_mapping.json', $messages);

$results = $this->getClient()->post('/messages', [
'model' => $model,
'system' => 'Return a markdown response.',
Expand All @@ -54,12 +52,12 @@ public function chat(array $messages): CompletionResponse
if (! $results->ok()) {
$error = $results->json()['error']['type'];
$message = $results->json()['error']['message'];
put_fixture('claude_error.json', $results->json());
Log::error('Claude API Error ', [
Log::error('Claude API Error Chat', [
'type' => $error,
'message' => $message,
]);
throw new \Exception('Claude API Error '.$message);

throw new \Exception('Claude API Error Chat');
}

$data = null;
Expand Down Expand Up @@ -93,8 +91,13 @@ public function completion(string $prompt): CompletionResponse

if (! $results->ok()) {
$error = $results->json()['error']['type'];
Log::error('Claude API Error '.$error);
throw new \Exception('Claude API Error '.$error);
$message = $results->json()['error']['message'];
Log::error('Claude API Error Chat', [
'type' => $error,
'message' => $message,
]);

throw new \Exception('Claude API Error Chat');
}

$data = null;
Expand Down Expand Up @@ -201,10 +204,10 @@ protected function getClient()
*/
public function functionPromptChat(array $messages, array $only = []): array
{

$messages = $this->remapMessages($messages);
Log::info('LlmDriver::ClaudeClient::functionPromptChat', $messages);

$functions = $this->getFunctions();
Log::info('LlmDriver::ClaudeClient::functionPromptChat', $messages);

$model = $this->getConfig('claude')['models']['completion_model'];
$maxTokens = $this->getConfig('claude')['max_tokens'];
Expand All @@ -222,11 +225,13 @@ public function functionPromptChat(array $messages, array $only = []): array
if (! $results->ok()) {
$error = $results->json()['error']['type'];
$message = $results->json()['error']['message'];
Log::error('Claude API Error ', [

Log::error('Claude API Error getting functions ', [
'type' => $error,
'message' => $message,
]);
throw new \Exception('Claude API Error '.$message);

throw new \Exception('Claude API Error getting functions');
}

$stop_reason = $results->json()['stop_reason'];
Expand Down Expand Up @@ -302,6 +307,7 @@ public function getFunctions(): array
*/
protected function remapMessages(array $messages): array
{
put_fixture('before_mapping.json', $messages);
$messages = collect($messages)->map(function ($item) {
if ($item->role === 'system') {
$item->role = 'assistant';
Expand Down Expand Up @@ -342,6 +348,8 @@ protected function remapMessages(array $messages): array

}

put_fixture('after_mapping.json', $newMessagesArray);

return $newMessagesArray;
}

Expand Down
73 changes: 45 additions & 28 deletions Modules/LlmDriver/app/Orchestrate.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

namespace LlmLaraHub\LlmDriver;

use App\Domains\Chat\ToolsDto;
use App\Domains\Messages\RoleEnum;
use App\Models\Chat;
use App\Models\Filter;
use App\Models\Message;
use App\Models\PromptHistory;
use Facades\App\Domains\Messages\SearchAndSummarizeChatRepo;
use Facades\LlmLaraHub\LlmDriver\Functions\StandardsChecker;
use Illuminate\Support\Arr;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\Functions\FunctionCallDto;
use LlmLaraHub\LlmDriver\Helpers\CreateReferencesTrait;
Expand Down Expand Up @@ -43,10 +43,15 @@ public function handle(
* after this refactor
*/
$messagesArray = $message->getLatestMessages();

put_fixture('latest_messages.json', $messagesArray);

$filter = $message->meta_data->filter;

if ($filter) {
$filter = Filter::find($filter);
}

$tool = $message->meta_data->tool;

if ($tool) {
Expand All @@ -73,11 +78,13 @@ public function handle(
'filter' => $filter,
]);

$this->addToolsToMessage($message, $functionDto);

$response = StandardsChecker::handle($message);
$messagesArray = $this->handleResponse($response, $chat);
$this->handleResponse($response, $chat, $message);
$this->response = $response->content;
$this->requiresFollowup = $response->requires_follow_up_prompt;
$this->requiresFollowUp($messagesArray, $chat);
$this->requiresFollowUp($message->getLatestMessages(), $chat);
}

} else {
Expand Down Expand Up @@ -123,6 +130,8 @@ public function handle(
'filter' => $filter,
]);

$this->addToolsToMessage($message, $functionDto);

/** @var FunctionResponse $response */
$response = $functionClass->handle($message);

Expand All @@ -137,7 +146,8 @@ public function handle(
message: $response->content,
role: RoleEnum::Assistant,
show_in_thread: true,
meta_data: $message->meta_data);
meta_data: $message->meta_data,
tools: $message->tools);
}

if ($response->prompt) {
Expand Down Expand Up @@ -188,40 +198,47 @@ protected function hasFunctions(array $functions): bool
/**
* @return MessageInDto[]
*/
protected function handleResponse(FunctionResponse $response, Chat $chat): array
protected function handleResponse(
FunctionResponse $response,
Chat $chat,
Message $message): void
{
$message = null;

if ($response->save_to_message) {
$message = $chat->addInput(
message: $response->content,
role: RoleEnum::Assistant,
show_in_thread: true);
}
show_in_thread: true,
meta_data: $message->meta_data,
tools: $message->tools);

if ($response->prompt) {
PromptHistory::create([
'prompt' => $response->prompt,
'chat_id' => $chat->getChat()->id,
'message_id' => $message?->id,
/** @phpstan-ignore-next-line */
'collection_id' => $chat->getChatable()?->id,
]);
}

if ($response->prompt) {
PromptHistory::create([
'prompt' => $response->prompt,
'chat_id' => $chat->getChat()->id,
'message_id' => $message?->id,
/** @phpstan-ignore-next-line */
'collection_id' => $chat->getChatable()?->id,
]);
if (! empty($response->documentChunks)) {
$this->saveDocumentReference(
$message,
$response->documentChunks
);
}
}
}

if (! empty($response->documentChunks)) {
$this->saveDocumentReference(
$message,
$response->documentChunks
);
protected function addToolsToMessage(Message $message, FunctionCallDto $functionDto): void
{
$tools = $message->tools;
if (! $tools) {
$tools = ToolsDto::from(['tools' => []]);
}

$messagesArray = Arr::wrap(MessageInDto::from([
'role' => 'assistant',
'content' => $response->content,
]));

return $messagesArray;
$tools->tools[] = $functionDto;
$message->updateQuietly(['tools' => $tools]);
}

protected function requiresFollowUp(array $messagesArray, Chat $chat): void
Expand Down
4 changes: 1 addition & 3 deletions app/Domains/Chat/ToolsDto.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@

class ToolsDto extends Data
{

/**
* @param FunctionCallDto[] $tools
* @param FunctionCallDto[] $tools
*/
public function __construct(
public array $tools = []
) {
}

}
2 changes: 2 additions & 0 deletions app/Http/Resources/MessageDocumentReferenceResource.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class MessageDocumentReferenceResource extends JsonResource
public function toArray(Request $request): array
{
$tags = $this->document_chunk?->tags;

if ($tags) {
$tags = TagResource::collection($tags);
}
Expand All @@ -28,6 +29,7 @@ public function toArray(Request $request): array
'section_number' => $this->document_chunk?->section_number + 1, //since 0 does not look good in the ui
'summary' => str($this->document_chunk?->content)->markdown(),
'taggings' => $tags,
'type' => $this->document_chunk?->document->type,
];
}
}
1 change: 0 additions & 1 deletion app/Listeners/AddChatTitleListener.php
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,5 @@ public function __construct()
public function handle(MessageCreatedEvent $event): void
{
TitleRepo::handle($event->message);
notify_ui($event->message->chat, 'Chat Title Updated');
}
}
8 changes: 5 additions & 3 deletions app/Models/Chat.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
namespace App\Models;

use App\Domains\Chat\MetaDataDto;
use App\Domains\Chat\ToolsDto;
use App\Domains\Messages\RoleEnum;
use Illuminate\Database\Eloquent\Factories\HasFactory;
use Illuminate\Database\Eloquent\Model;
Expand Down Expand Up @@ -87,13 +88,14 @@ public function addInput(string $message,
RoleEnum $role = RoleEnum::User,
?string $systemPrompt = null,
bool $show_in_thread = true,
?MetaDataDto $meta_data = null): Message
?MetaDataDto $meta_data = null,
?ToolsDto $tools = null): Message
{
if (! $meta_data) {
$meta_data = MetaDataDto::from([]);
}

return DB::transaction(function () use ($message, $role, $systemPrompt, $show_in_thread, $meta_data) {
return DB::transaction(function () use ($message, $role, $tools, $systemPrompt, $show_in_thread, $meta_data) {

if ($systemPrompt) {
$this->createSystemMessageIfNeeded($systemPrompt);
Expand All @@ -109,6 +111,7 @@ public function addInput(string $message,
'chat_id' => $this->id,
'is_chat_ignored' => ! $show_in_thread,
'meta_data' => $meta_data,
'tools' => $tools,
]);
});

Expand Down Expand Up @@ -144,7 +147,6 @@ public function getChatResponse(int $limit = 5): array
{
$latestMessages = $this->messages()
->orderBy('id', 'desc')
->limit(5)
->get();

$latestMessagesArray = [];
Expand Down
1 change: 0 additions & 1 deletion app/Models/Message.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
use Illuminate\Database\Eloquent\Model;
use Illuminate\Database\Eloquent\Relations\BelongsTo;
use Illuminate\Database\Eloquent\Relations\HasMany;
use LlmLaraHub\LlmDriver\Functions\FunctionCallDto;
use LlmLaraHub\LlmDriver\HasDrivers;

class Message extends Model implements HasDrivers
Expand Down
9 changes: 8 additions & 1 deletion resources/js/Pages/Chat/Components/ReferenceTable.vue
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@
<tr v-for="reference in message.message_document_references" :key="reference.id">
<th>{{ reference.id }}</th>
<td>
<a class="underline" :href="route('download.document', {
<a
v-if="reference.type === 'html'"
target="_blank"
class="underline"
:href="reference.document_name">{{ reference.document_name }}</a>
<a
v-else
class="underline" :href="route('download.document', {
collection: message.collection_id,
document_name: reference.document_name
})">{{ reference.document_name }}</a>
Expand Down
1 change: 0 additions & 1 deletion tests/Feature/Models/MessageTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ public function test_factory(): void
$this->assertNotNull($model->tools);
$this->assertInstanceOf(ToolsDto::class, $model->tools);


}

public function test_get_filter(): void
Expand Down
4 changes: 4 additions & 0 deletions tests/Feature/OrchestrateTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public function test_gets_summarize_function(): void
]);

$message = Message::factory()->user()->create([
'tools' => [],
'meta_data' => MetaDataDto::from(
[
'tool' => '',
Expand All @@ -76,7 +77,10 @@ public function test_gets_summarize_function(): void
Event::assertDispatched(ChatUiUpdateEvent::class);

$this->assertEquals($results, 'This is the summary of the collection');

$this->assertDatabaseCount('prompt_histories', 1);

$this->assertCount(1, $message->tools->tools);
}

public function test_tool_standards_checker(): void
Expand Down
Loading

0 comments on commit 85d8823

Please sign in to comment.