Skip to content

Commit

Permalink
slowly moving in tools
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed Aug 10, 2024
1 parent 321641f commit 188c91c
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 29 deletions.
10 changes: 8 additions & 2 deletions Modules/LlmDriver/app/BaseClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ public function chat(array $messages): CompletionResponse

$data = fake()->sentences(3, true);

return new CompletionResponse($data);
return CompletionResponse::from([
'content' => $data,
'stop_reason' => 'stop',
]);
}

public function completion(string $prompt): CompletionResponse
Expand All @@ -135,7 +138,10 @@ public function completion(string $prompt): CompletionResponse
Voluptate irure cillum dolor anim officia reprehenderit dolor. Eiusmod veniam nostrud consectetur incididunt proident id. Anim adipisicing pariatur amet duis Lorem sunt veniam veniam est. Deserunt ea aliquip cillum pariatur consectetur. Dolor in reprehenderit adipisicing consectetur cupidatat ad cupidatat reprehenderit. Nostrud mollit voluptate aliqua anim pariatur excepteur eiusmod velit quis exercitation tempor quis excepteur.
EOD;

return new CompletionResponse($data);
return CompletionResponse::from([
'content' => $data,
'stop_reason' => 'stop',
]);
}

protected function getConfig(string $driver): array
Expand Down
20 changes: 2 additions & 18 deletions Modules/LlmDriver/app/ClaudeClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use Laravel\Pennant\Feature;
use LlmLaraHub\LlmDriver\Functions\FunctionDto;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\ClaudeCompletionResponse;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;

Expand All @@ -37,25 +38,16 @@ public function chat(array $messages): CompletionResponse

Log::info('LlmDriver::Claude::chat');

/**
* I need to iterate over each item
* then if there are two rows with role assistant I need to insert
* in between a user row with some copy to make it work like "And the user search results had"
* using the Laravel Collection library
*/
$messages = $this->remapMessages($messages);

$payload = [
'model' => $model,
'system' => 'Return a markdown response.',
'max_tokens' => $maxTokens,
'messages' => $messages,
];

$payload = $this->modifyPayload($payload);

put_fixture('claude_payload_chat.json', $payload);

$results = $this->getClient()->post('/messages', $payload);

if (! $results->ok()) {
Expand All @@ -69,15 +61,7 @@ public function chat(array $messages): CompletionResponse
throw new \Exception('Claude API Error Chat');
}

[$data, $tool_used, $stop_reason] = $this->getContentAndToolTypeFromResults($results);

return CompletionResponse::from([
'content' => $data,
'tool_used' => $tool_used,
'stop_reason' => $stop_reason,
'input_tokens' => data_get($results, 'usage.input_tokens', null),
'output_tokens' => data_get($results, 'usage.output_tokens', null),
]);
return ClaudeCompletionResponse::from($results->json());
}

public function completion(string $prompt): CompletionResponse
Expand Down
2 changes: 0 additions & 2 deletions Modules/LlmDriver/app/GroqClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@ public function embedData(string $data): EmbeddingsResponseDto
public function chat(array $messages): CompletionResponse
{
$model = $this->getConfig('groq')['models']['completion_model'];
$maxTokens = $this->getConfig('groq')['max_tokens'];

Log::info('LlmDriver::Groq::chat');

$messages = $this->remapMessages($messages);

$results = $this->getClient()->post('/chat/completions', [
'model' => $model,
'max_tokens' => $maxTokens,
'messages' => $this->messagesToArray($messages),
]);

Expand Down
13 changes: 11 additions & 2 deletions Modules/LlmDriver/app/OllamaClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ public function chat(array $messages): CompletionResponse
throw new \Exception('Ollama API Error');
}

return new CompletionResponse($response->json()['message']['content']);
return CompletionResponse::from([
'content' => $response->json()['message']['content'],
'stop_reason' => 'stop',
]);
}

/**
Expand Down Expand Up @@ -186,9 +189,15 @@ public function completion(string $prompt): CompletionResponse
'stream' => false,
]);

/**
* @see https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
*/
$results = $response->json()['response'];

return new CompletionResponse($results);
return CompletionResponse::from([
'content' => $results,
'stop_reason' => 'stop',
]);
}

protected function getClient()
Expand Down
27 changes: 27 additions & 0 deletions Modules/LlmDriver/app/Responses/ClaudeCompletionResponse.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Attributes\MapInputName;
use Spatie\LaravelData\Attributes\WithCastable;
use Spatie\LaravelData\Optional;

class ClaudeCompletionResponse extends CompletionResponse
{
public function __construct(
#[WithCastable(ClaudeContentCaster::class)]
public mixed $content,
public string|Optional $stop_reason,
public ?string $tool_used = '',
/** @var array<ToolDto> */
#[WithCastable(ClaudeToolCaster::class)]
#[MapInputName('content')]
public array $tool_calls = [],
#[MapInputName('usage.input_tokens')]
public ?int $input_tokens = null,
#[MapInputName('usage.output_tokens')]
public ?int $output_tokens = null,
public ?string $model = null,
) {
}
}
33 changes: 33 additions & 0 deletions Modules/LlmDriver/app/Responses/ClaudeContentCaster.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Casts\Cast;
use Spatie\LaravelData\Casts\Castable;
use Spatie\LaravelData\Support\Creation\CreationContext;
use Spatie\LaravelData\Support\DataProperty;

class ClaudeContentCaster implements Castable
{
public function __construct(public array $content)
{

}

public static function dataCastUsing(...$arguments): Cast
{
return new class implements Cast
{
public function cast(DataProperty $property, mixed $value, array $properties, CreationContext $context): mixed
{
$results = collect($value)->filter(
function ($item) {
return $item['type'] === 'text';
}
)->first();

return data_get($results, 'text');
}
};
}
}
43 changes: 43 additions & 0 deletions Modules/LlmDriver/app/Responses/ClaudeToolCaster.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Casts\Cast;
use Spatie\LaravelData\Casts\Castable;
use Spatie\LaravelData\Support\Creation\CreationContext;
use Spatie\LaravelData\Support\DataProperty;

class ClaudeToolCaster implements Castable
{
public function __construct(public array $tools)
{

}

public static function dataCastUsing(...$arguments): Cast
{
return new class implements Cast
{
public function cast(DataProperty $property, mixed $value, array $properties, CreationContext $context): mixed
{
$results = collect($value)->filter(
function ($item) {
return $item['type'] === 'tool_use';
}
)->toArray();

foreach ($results as $index => $result) {
$results[$index] = ToolDto::from(
[
'name' => $result['name'],
'arguments' => $result['input'],
'id' => $result['id'],
]
);
}

return $results;
}
};
}
}
11 changes: 8 additions & 3 deletions Modules/LlmDriver/app/Responses/CompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Optional;

class CompletionResponse extends \Spatie\LaravelData\Data
{
public function __construct(
public string $content,
public string $stop_reason = 'end_turn',
public ?string $tool_used = null,
public mixed $content,
public string|Optional $stop_reason,
public ?string $tool_used = '',
/** @var array<ToolDto> */
public array $tool_calls = [],
public ?int $input_tokens = null,
public ?int $output_tokens = null,
public ?string $model = null,
) {
}
}
16 changes: 16 additions & 0 deletions Modules/LlmDriver/app/Responses/ToolDto.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Data;
use Spatie\LaravelData\Optional;

class ToolDto extends Data
{
public function __construct(
public string $name,
public array $arguments,
public string|Optional $id = '',
) {
}
}
2 changes: 0 additions & 2 deletions Modules/LlmDriver/tests/Feature/MockClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ public function test_completion(): void
$client = new MockClient();

$results = $client->completion('test');

$this->assertInstanceOf(CompletionResponse::class, $results);

}

public function test_Chat(): void
Expand Down
38 changes: 38 additions & 0 deletions tests/Api/ClientsTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
<?php

namespace Tests\Api;

use App\Domains\Messages\RoleEnum;
use App\Models\Setting;
use App\Models\Source;
use Illuminate\Support\Arr;
use Illuminate\Support\Facades\Http;
use LlmLaraHub\LlmDriver\LlmDriverFacade;
use App\Domains\EmailParser\MailDto;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use Tests\TestCase;

class ClientsTest extends TestCase
{


public function test_clients() {
$this->markTestSkipped('@TODO will setup the tokens shortly');
$prompt = <<<PROMPT
What do you know about Laravel
PROMPT;

$messages = [];
$messages[] = MessageInDto::from([
"role" => RoleEnum::User->value,
"content" => $prompt
]);
$results = LlmDriverFacade::driver("groq")->chat($messages);
$this->assertNotNull($results->content);
$results = LlmDriverFacade::driver("openai")->chat($messages);
$this->assertNotNull($results->content);
$results = LlmDriverFacade::driver("claude")->chat($messages);
$this->assertNotNull($results->content);

}
}

0 comments on commit 188c91c

Please sign in to comment.