
Commit

This will start to make just chat work with tools; then we can limit the open area to the tools below.
alnutile committed Aug 4, 2024
1 parent 6d64d0d commit 4dbe5c1
Showing 36 changed files with 1,740 additions and 180 deletions.
17 changes: 16 additions & 1 deletion Modules/LlmDriver/app/BaseClient.php
@@ -32,13 +32,22 @@ abstract class BaseClient

protected ToolTypes $toolType;

protected bool $limitByShowInUi = false;

public function setToolType(ToolTypes $toolType): self
{
$this->toolType = $toolType;

return $this;
}

public function setLimitByShowInUi(bool $limitByShowInUi): self
{
$this->limitByShowInUi = $limitByShowInUi;

return $this;
}

public function setForceTool(FunctionDto $tool): self
{
$this->forceTool = $tool;
@@ -61,6 +70,7 @@ public function modifyPayload(array $payload, bool $noTools = false): array
$payload['tools'] = $this->getFunctions();
}

put_fixture('ollama_modified_payload.json', $payload);
return $payload;
}

@@ -219,12 +229,17 @@ public function getFunctions(): array
});
}

if ($this->limitByShowInUi) {
$functions = $functions->filter(function (FunctionContract $function) {
return $function->showInUi;
});
}

return $functions->transform(
function (FunctionContract $function) {
return $function->getFunction();
}
)->toArray();

}

public function remapFunctions(array $functions): array
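The two additions above give each driver a fluent way to scope its tool list. Below is a minimal usage sketch, assuming the existing LlmDriverFacade and ToolTypes enum (namespaces and the 'ollama' driver key are assumptions; only setLimitByShowInUi() and the showInUi filter come from this commit).

<?php

use LlmLaraHub\LlmDriver\LlmDriverFacade; // namespace assumed from this diff
use LlmLaraHub\LlmDriver\ToolTypes;       // namespace assumed

// Only tools flagged showInUi = true come back, so a UI picker can render a
// trimmed list while the orchestrator still sees every tool for the scope.
$uiTools = LlmDriverFacade::driver('ollama')
    ->setLimitByShowInUi(true)
    ->getFunctions();

// Without the flag, the full tool list for the current tool type is returned.
$allTools = LlmDriverFacade::driver('ollama')
    ->setToolType(ToolTypes::Chat)
    ->getFunctions();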
2 changes: 2 additions & 0 deletions Modules/LlmDriver/app/Functions/Chat.php
@@ -13,6 +13,8 @@ class Chat extends FunctionContract
{
use ChatHelperTrait, ToolsHelper;

public bool $showInUi = false;

public array $toolTypes = [
ToolTypes::NoFunction,
];
2 changes: 1 addition & 1 deletion Modules/LlmDriver/app/Functions/CreateDocument.php
@@ -20,7 +20,7 @@ class CreateDocument extends FunctionContract

protected string $name = 'create_document';

-protected string $description = 'Create a document in the collection of this local system';
+protected string $description = 'Create or Save a document into the collection of this local system using the content provided';

public function handle(
Message $message): FunctionResponse
2 changes: 2 additions & 0 deletions Modules/LlmDriver/app/Functions/FunctionContract.php
@@ -10,6 +10,8 @@ abstract class FunctionContract
{
protected string $name;

public bool $showInUi = true;

public array $toolTypes = [
ToolTypes::Chat,
ToolTypes::ChatCompletion,
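With showInUi defaulting to true here and Chat opting out above, BaseClient::getFunctions() effectively runs a two-stage filter: first by tool scope, then optionally by UI visibility. The stand-alone sketch below illustrates that logic with plain Laravel collections; the tool metadata is fabricated for the example and the real filter operates on FunctionContract objects.

<?php

use Illuminate\Support\Collection;

// Fabricated tool metadata for illustration only.
$tools = collect([
    ['name' => 'chat', 'show_in_ui' => false, 'scopes' => ['no_function']],
    ['name' => 'create_document', 'show_in_ui' => true, 'scopes' => ['chat', 'chat_completion']],
    ['name' => 'search_and_summarize', 'show_in_ui' => true, 'scopes' => ['chat_completion']],
]);

$scope = 'chat_completion';
$limitByShowInUi = true;

$visible = $tools
    // stage 1: keep tools registered for the requested tool type
    ->filter(fn (array $tool) => in_array($scope, $tool['scopes'], true))
    // stage 2: only applied when setLimitByShowInUi(true) was called
    ->when($limitByShowInUi, fn (Collection $c) => $c->filter(fn (array $tool) => $tool['show_in_ui']))
    ->values();

// ['create_document', 'search_and_summarize']
print_r($visible->pluck('name')->all());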
6 changes: 5 additions & 1 deletion Modules/LlmDriver/app/LlmDriverClient.php
@@ -2,6 +2,8 @@

namespace LlmLaraHub\LlmDriver;

use LlmLaraHub\LlmDriver\Functions\FunctionDto;

class LlmDriverClient
{
protected $drivers = [];
@@ -42,7 +44,9 @@ protected function createDriver($name)
public function getFunctionsForUi(): array
{
return collect(
-LlmDriverFacade::driver('mock')->getFunctions()
+LlmDriverFacade::driver('mock')
+    ->setLimitByShowInUi(true)
+    ->getFunctions()
)
->map(function ($item) {
$item['name_formatted'] = str($item['name'])->headline()->toString();
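For reference, the only thing getFunctionsForUi() adds on top of the (now UI-limited) function list is a display-friendly name. A small sketch with made-up input, showing what the str()->headline() transformation produces:

<?php

$functions = [
    ['name' => 'create_document', 'description' => 'Create or Save a document into the collection ...'],
];

$forUi = collect($functions)
    ->map(function (array $item) {
        // headline() turns snake_case names into "Create Document"
        $item['name_formatted'] = str($item['name'])->headline()->toString();

        return $item;
    })
    ->toArray();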
35 changes: 19 additions & 16 deletions Modules/LlmDriver/app/OllamaClient.php
@@ -7,10 +7,12 @@
use Illuminate\Http\Client\Pool;
use Illuminate\Support\Facades\Http;
use Illuminate\Support\Facades\Log;
use Illuminate\Support\Str;
use Laravel\Pennant\Feature;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;
use LlmLaraHub\LlmDriver\Responses\OllamaChatCompletionResponse;
use LlmLaraHub\LlmDriver\Responses\OllamaCompletionResponse;

class OllamaClient extends BaseClient
@@ -93,7 +95,7 @@ public function functionPromptChat(array $messages, array $only = []): array
*/
public function chat(array $messages): CompletionResponse
{
-Log::info('LlmDriver::OllamaClient::completion');
+Log::info('LlmDriver::OllamaClient::chat');

$messages = $this->remapMessages($messages);

@@ -108,10 +110,6 @@ public function chat(array $messages): CompletionResponse

$payload = $this->modifyPayload($payload);

-Log::info('LlmDriver::Ollama::chat', [
-    'payload' => $payload,
-]);

$response = $this->getClient()->post('/chat', $payload);

if ($response->failed()) {
@@ -121,7 +119,7 @@ public function chat(array $messages): CompletionResponse
throw new \Exception('Ollama API Error Chat');
}

-return OllamaCompletionResponse::from($response->json());
+return OllamaChatCompletionResponse::from($response->json());
}

/**
@@ -194,6 +192,14 @@ public function completion(string $prompt): CompletionResponse
'stream' => false,
]);

if ($response->failed()) {
Log::error('Ollama Completion API Error ', [
'error' => $response->body(),
]);
throw new \Exception('Ollama API Error Completion');
}

put_fixture('ollama_completion.json', $response->json());
return OllamaCompletionResponse::from($response->json());
}

@@ -222,7 +228,7 @@ public function getFunctions(): array
{
$functions = parent::getFunctions();

-return collect($functions)->map(function ($function) {
+$results = collect($functions)->map(function ($function) {
$properties = [];
$required = [];

@@ -254,7 +260,11 @@ public function getFunctions(): array

];

-})->toArray();
+})->values()->toArray();

put_fixture('ollama_functions.json', $results);

return $results;
}

public function isAsync(): bool
@@ -278,13 +288,6 @@ public function remapMessages(array $messages): array
->toArray();
})->toArray();

-if (in_array($this->getConfig('ollama')['models']['completion_model'], ['llama3.1', 'llama3'])) {
-    Log::info('[LaraChain] LlmDriver::OllamaClient::remapMessages');
-    $messages = collect($messages)->reverse();
-}
-
-put_fixture('ollama_messages_after_remap.json', $messages->values()->toArray());
-
-return $messages->values()->toArray();
+return $messages;
}
}
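The collapsed hunk hides the exact array the overridden getFunctions() returns, but Ollama's /api/chat endpoint documents an OpenAI-style tool schema, so the remapped $results plausibly look like the sketch below. Field values are examples, not fixtures from this repo.

<?php

// One remapped tool as Ollama expects it in the "tools" key of a /api/chat payload.
$tool = [
    'type' => 'function',
    'function' => [
        'name' => 'create_document',
        'description' => 'Create or Save a document into the collection of this local system using the content provided',
        'parameters' => [
            'type' => 'object',
            'properties' => [
                'content' => [
                    'type' => 'string',
                    'description' => 'The content to save', // example property, not from the diff
                ],
            ],
            'required' => ['content'],
        ],
    ],
];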
26 changes: 26 additions & 0 deletions Modules/LlmDriver/app/Responses/OllamaChatCompletionResponse.php
@@ -0,0 +1,26 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Attributes\MapInputName;
use Spatie\LaravelData\Optional;

class OllamaChatCompletionResponse extends CompletionResponse
{
public function __construct(
#[MapInputName('message.content')]
public mixed $content,
#[MapInputName('done_reason')]
public string|Optional $stop_reason,
public ?string $tool_used = '',
/** @var array<OllamaToolDto> */
#[MapInputName('message.tool_calls')]
public array $tool_calls = [],
#[MapInputName('prompt_eval_count')]
public ?int $input_tokens = null,
#[MapInputName('eval_count')]
public ?int $output_tokens = null,
public ?string $model = null,
) {
}
}
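A minimal hydration sketch for the new DTO, assuming a payload shaped like Ollama's documented /api/chat output; the MapInputName attributes above pull the nested message fields out of the response (spatie/laravel-data supports dot notation there). The payload values are examples only.

<?php

use LlmLaraHub\LlmDriver\Responses\OllamaChatCompletionResponse;

$json = [
    'model' => 'llama3.1',
    'message' => [
        'role' => 'assistant',
        'content' => 'Hello there!',
        'tool_calls' => [],
    ],
    'done' => true,
    'done_reason' => 'stop',
    'prompt_eval_count' => 26,
    'eval_count' => 12,
];

$response = OllamaChatCompletionResponse::from($json);

echo $response->content;      // "Hello there!"
echo $response->stop_reason;  // "stop"
echo $response->input_tokens; // 26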
4 changes: 2 additions & 2 deletions Modules/LlmDriver/app/Responses/OllamaCompletionResponse.php
@@ -8,13 +8,13 @@
class OllamaCompletionResponse extends CompletionResponse
{
public function __construct(
-#[MapInputName('message.content')]
+#[MapInputName('response')]
public mixed $content,
#[MapInputName('done_reason')]
public string|Optional $stop_reason,
public ?string $tool_used = '',
/** @var array<OllamaToolDto> */
-#[MapInputName('message.tool_calls')]
+#[MapInputName('tool_calls')]
public array $tool_calls = [],
#[MapInputName('prompt_eval_count')]
public ?int $input_tokens = null,
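The split between the two response classes mirrors the two Ollama endpoints: /api/generate returns the text at the top level under response, while /api/chat nests it under message. Abbreviated example payloads based on Ollama's documented shapes (not fixtures from this repo):

<?php

// Shape consumed by OllamaCompletionResponse (from /api/generate).
$generate = [
    'model' => 'llama3.1',
    'response' => 'Summary of the document ...',
    'done_reason' => 'stop',
    'prompt_eval_count' => 120,
    'eval_count' => 64,
];

// Shape consumed by OllamaChatCompletionResponse (from /api/chat).
$chat = [
    'model' => 'llama3.1',
    'message' => ['role' => 'assistant', 'content' => 'Summary of the document ...'],
    'done_reason' => 'stop',
];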
41 changes: 23 additions & 18 deletions app/Domains/Orchestration/OrchestrateVersionTwo.php
@@ -35,6 +35,7 @@ public function handle(
Chat $chat,
Message $message)
{
$toolType = ToolTypes::ChatCompletion;

//OpenAI
//@see https://platform.openai.com/docs/guides/function-calling
Expand Down Expand Up @@ -78,28 +79,32 @@ public function handle(
* Here the user is just forcing a chat
* they want to continue with the thread
*/
-if ($message->meta_data?->tool === 'chat') {
-    Log::info('[LaraChain] - Just Chatting');
-    $this->justChat($chat, $message, ToolTypes::Chat);
-} else {
-    $messages = $chat->getChatResponse();
-
-    put_fixture('orchestrate_messages_fist_send.json', $messages);
-
-    Log::info('[LaraChain] - Looking for Tools');
-    $response = LlmDriverFacade::driver($message->getDriver())
-        ->setToolType(ToolTypes::ChatCompletion)
-        ->chat($messages);
-
-    if (! empty($response->tool_calls)) {
-        Log::info('[LaraChain] - Tools Found');
-        $this->chatWithTools($chat, $message, $response);
-
-    } else {
-        //hmm
-        Log::info('[LaraChain] - No Tools found just gonna chat');
-        $this->justChat($chat, $message, ToolTypes::NoFunction);
-    }
-}
+if ($message->meta_data?->tool === ToolTypes::Chat->value) {
+    $toolType = ToolTypes::Chat;
+
+    Log::info('[LaraChain] - Setting it as a chat tool scope', [
+        'tool_type' => $toolType,
+    ]);
+}
+
+$messages = $chat->getChatResponse();
+
+put_fixture('orchestrate_messages_first_send.json', $messages);
+
+Log::info('[LaraChain] - Looking for Tools');
+
+$response = LlmDriverFacade::driver($message->getDriver())
+    ->setToolType($toolType)
+    ->chat($messages);
+
+if (! empty($response->tool_calls)) {
+    Log::info('[LaraChain] - Tools Found');
+    $this->chatWithTools($chat, $message, $response);
+
+} else {
+    //hmm
+    Log::info('[LaraChain] - No Tools found just gonna chat');
+    $this->justChat($chat, $message, ToolTypes::NoFunction);
+}
}

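Condensed, the new flow in handle() reads roughly like the sketch below (names taken from the diff, class boilerplate omitted): choosing "chat" in the UI only narrows the tool scope; the driver is always asked for tool calls, and justChat() remains the fallback when none come back.

<?php

// Sketch only — this mirrors the diff above rather than adding behavior.
$toolType = $message->meta_data?->tool === ToolTypes::Chat->value
    ? ToolTypes::Chat            // user explicitly picked "chat" in the UI
    : ToolTypes::ChatCompletion; // default scope set at the top of handle()

$response = LlmDriverFacade::driver($message->getDriver())
    ->setToolType($toolType)
    ->chat($chat->getChatResponse());

empty($response->tool_calls)
    ? $this->justChat($chat, $message, ToolTypes::NoFunction)
    : $this->chatWithTools($chat, $message, $response);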
25 changes: 0 additions & 25 deletions app/Jobs/SummarizeDocumentJob.php
@@ -71,7 +71,6 @@ public function handle(): void
} else {
$prompt = Templatizer::appendContext(true)
->handle($this->prompt, $content);

}

/** @var CompletionResponse $results */
@@ -81,30 +80,6 @@ public function handle(): void

$this->results = $results->content;

-if (Feature::active('verification_prompt_summary')) {
-
-    $verifyPrompt = <<<'PROMPT'
-    This the content from all the documents in this collection.
-    Then that was passed into the LLM to summarize the results.
-    PROMPT;
-
-    $dto = VerifyPromptInputDto::from(
-        [
-            'chattable' => $this->document->collection,
-            'originalPrompt' => $prompt,
-            'context' => $content,
-            'llmResponse' => $this->results,
-            'verifyPrompt' => $verifyPrompt,
-        ]
-    );
-
-    /** @var VerifyPromptOutputDto $response */
-    $response = VerifyResponseAgent::verify($dto);
-
-    $this->results = $response->response;
-
-}

$this->document->update([
'summary' => $this->results,
'status_summary' => StatusEnum::SummaryComplete,
1 change: 1 addition & 0 deletions app/Models/Document.php
@@ -171,6 +171,7 @@ public static function make(
'collection_id' => $collection->id,
'type' => TypesEnum::Txt,
'subject' => str($content)->limit(256)->toString(),
'summary' => $content,
'original_content' => $content,
'status_summary' => StatusEnum::Pending,
]);