Skip to content

Commit

Permalink
Merge pull request #37 from LlmLaraHub/next_message_deeper_into_funct…
Browse files Browse the repository at this point in the history
…ions

Next move Message deeper into functions
  • Loading branch information
alnutile authored Jul 6, 2024
2 parents 39dce4a + e82e8a3 commit a7afaf9
Show file tree
Hide file tree
Showing 57 changed files with 1,086 additions and 708 deletions.
53 changes: 39 additions & 14 deletions Modules/LlmDriver/app/ClaudeClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use App\Models\Setting;
use Illuminate\Http\Client\Pool;
use Illuminate\Http\Client\Response;
use Illuminate\Support\Arr;
use Illuminate\Support\Facades\Http;
use Illuminate\Support\Facades\Log;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
Expand Down Expand Up @@ -37,13 +38,11 @@ public function chat(array $messages): CompletionResponse
/**
* I need to iterate over each item
* then if there are two rows with role assistant I need to insert
* in betwee a user row with some copy to make it work like "And the user search results had"
* in between a user row with some copy to make it work like "And the user search results had"
* using the Laravel Collection library
*/
$messages = $this->remapMessages($messages);

put_fixture('after_mapping.json', $messages);

$results = $this->getClient()->post('/messages', [
'model' => $model,
'system' => 'Return a markdown response.',
Expand All @@ -54,12 +53,12 @@ public function chat(array $messages): CompletionResponse
if (! $results->ok()) {
$error = $results->json()['error']['type'];
$message = $results->json()['error']['message'];
put_fixture('claude_error.json', $results->json());
Log::error('Claude API Error ', [
Log::error('Claude API Error Chat', [
'type' => $error,
'message' => $message,
]);
throw new \Exception('Claude API Error '.$message);

throw new \Exception('Claude API Error Chat');
}

$data = null;
Expand Down Expand Up @@ -93,8 +92,13 @@ public function completion(string $prompt): CompletionResponse

if (! $results->ok()) {
$error = $results->json()['error']['type'];
Log::error('Claude API Error '.$error);
throw new \Exception('Claude API Error '.$error);
$message = $results->json()['error']['message'];
Log::error('Claude API Error Chat', [
'type' => $error,
'message' => $message,
]);

throw new \Exception('Claude API Error Chat');
}

$data = null;
Expand Down Expand Up @@ -201,10 +205,15 @@ protected function getClient()
*/
public function functionPromptChat(array $messages, array $only = []): array
{
$messages = $this->remapMessages($messages);
Log::info('LlmDriver::ClaudeClient::functionPromptChat', $messages);

$functions = $this->getFunctions();
$messages = $this->remapMessages($messages, true);

/**
* @NOTE
* The api will not let me end this array in an assistant message
* it has to end in a user message
*/
Log::info('LlmDriver::ClaudeClient::functionPromptChat', $messages);

$model = $this->getConfig('claude')['models']['completion_model'];
$maxTokens = $this->getConfig('claude')['max_tokens'];
Expand All @@ -222,11 +231,13 @@ public function functionPromptChat(array $messages, array $only = []): array
if (! $results->ok()) {
$error = $results->json()['error']['type'];
$message = $results->json()['error']['message'];
Log::error('Claude API Error ', [

Log::error('Claude API Error getting functions ', [
'type' => $error,
'message' => $message,
]);
throw new \Exception('Claude API Error '.$message);

throw new \Exception('Claude API Error getting functions');
}

$stop_reason = $results->json()['stop_reason'];
Expand Down Expand Up @@ -300,8 +311,9 @@ public function getFunctions(): array
*
* @param MessageInDto[] $messages
*/
protected function remapMessages(array $messages): array
protected function remapMessages(array $messages, bool $userLast = false): array
{
put_fixture('before_mapping.json', $messages);
$messages = collect($messages)->map(function ($item) {
if ($item->role === 'system') {
$item->role = 'assistant';
Expand Down Expand Up @@ -342,6 +354,19 @@ protected function remapMessages(array $messages): array

}

if ($userLast) {
$last = Arr::last($newMessagesArray);

if ($last['role'] === 'assistant') {
$newMessagesArray[] = [
'role' => 'user',
'content' => 'Using the surrounding context to continue this response thread',
];
}
}

put_fixture('after_mapping.json', $newMessagesArray);

return $newMessagesArray;
}

Expand Down
2 changes: 0 additions & 2 deletions Modules/LlmDriver/app/DistanceQuery/Drivers/Base.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use App\Domains\Chat\MetaDataDto;
use App\Models\DocumentChunk;
use App\Models\Filter;
use Illuminate\Support\Collection;
use Pgvector\Laravel\Vector;

Expand All @@ -14,7 +13,6 @@ abstract public function cosineDistance(
string $embeddingSize,
int $collectionId,
Vector $embedding,
?Filter $filter = null,
?MetaDataDto $meta_data = null
): Collection;

Expand Down
2 changes: 0 additions & 2 deletions Modules/LlmDriver/app/DistanceQuery/Drivers/Mock.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
use App\Domains\Chat\MetaDataDto;
use App\Models\Collection as CollectionModel;
use App\Models\DocumentChunk;
use App\Models\Filter;
use Illuminate\Support\Collection;
use Pgvector\Laravel\Vector;

Expand All @@ -15,7 +14,6 @@ public function cosineDistance(
string $embeddingSize,
int $collectionId,
Vector $embedding,
?Filter $filter = null,
?MetaDataDto $meta_data = null
): Collection {
$documents = CollectionModel::find($collectionId)->documents->pluck('id');
Expand Down
24 changes: 13 additions & 11 deletions Modules/LlmDriver/app/DistanceQuery/Drivers/PostGres.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
use App\Domains\Chat\MetaDataDto;
use App\Models\Document;
use App\Models\DocumentChunk;
use App\Models\Filter;
use Illuminate\Support\Collection;
use Illuminate\Support\Facades\Log;
use Pgvector\Laravel\Distance;
Expand All @@ -18,10 +17,12 @@ public function cosineDistance(
string $embeddingSize,
int $collectionId,
Vector $embedding,
?Filter $filter = null,
?MetaDataDto $meta_data = null
): Collection {

$filter = $meta_data?->getFilter();
$date_range = $meta_data?->date_range;

Log::info('[LaraChain] - PostgresSQL Cosine Query', [
'filter' => $filter?->toArray(),
'embedding_size' => $embeddingSize,
Expand All @@ -32,16 +33,17 @@ public function cosineDistance(
->when($filter, function ($query, $filter) {
$query->whereIn('id', $filter->documents()->pluck('id'));
})
->when($meta_data, function ($query, $meta_data) {
if ($meta_data->date_range) {
$results = DateRangesEnum::getStartAndEndDates($meta_data->date_range);
->when($date_range, function ($query, $date_range) {
Log::info('Date Range', [
'date_range' => $date_range,
]);
$results = DateRangesEnum::getStartAndEndDates($date_range);

$query->whereBetween(
'created_at', [
$results['start'],
$results['end'],
]);
}
$query->whereBetween(
'created_at', [
$results['start'],
$results['end'],
]);
})
->where('documents.collection_id', $collectionId)
->orderBy('id')
Expand Down
5 changes: 5 additions & 0 deletions Modules/LlmDriver/app/Functions/ArgumentCaster.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ class ArgumentCaster implements Cast
{
public function cast(DataProperty $property, mixed $value, array $properties, CreationContext $context): array
{

if (is_array($value)) {
return $value;
}

return json_decode($value, true);
}
}
11 changes: 3 additions & 8 deletions Modules/LlmDriver/app/Functions/FunctionContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

namespace LlmLaraHub\LlmDriver\Functions;

use LlmLaraHub\LlmDriver\HasDrivers;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use App\Models\Message;
use LlmLaraHub\LlmDriver\Responses\FunctionResponse;

abstract class FunctionContract
Expand All @@ -14,13 +13,9 @@ abstract class FunctionContract

protected string $type = 'object';

/**
* @param MessageInDto[] $messageArray
*/
abstract public function handle(
array $messageArray,
HasDrivers $model,
FunctionCallDto $functionCallDto): FunctionResponse;
Message $message,
): FunctionResponse;

public function getFunction(): FunctionDto
{
Expand Down
Loading

0 comments on commit a7afaf9

Please sign in to comment.