Skip to content

Commit

Permalink
make the claude response dto smarter and add tools to it like ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed Jul 31, 2024
1 parent e4bdc5c commit bdb7c52
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 8 deletions.
6 changes: 2 additions & 4 deletions Modules/LlmDriver/app/ClaudeClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use Laravel\Pennant\Feature;
use LlmLaraHub\LlmDriver\Functions\FunctionDto;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\ClaudeCompletionResponse;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;

Expand Down Expand Up @@ -69,12 +70,9 @@ public function chat(array $messages): CompletionResponse

[$data, $tool_used, $stop_reason] = $this->getContentAndToolTypeFromResults($results);

return CompletionResponse::from([
return ClaudeCompletionResponse::from([
'content' => $data,
'tool_used' => $tool_used,
'stop_reason' => $stop_reason,
'input_tokens' => data_get($results, 'usage.input_tokens', null),
'output_tokens' => data_get($results, 'usage.output_tokens', null),
]);
}

Expand Down
27 changes: 27 additions & 0 deletions Modules/LlmDriver/app/Responses/ClaudeCompletionResponse.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Attributes\MapInputName;
use Spatie\LaravelData\Attributes\WithCastable;
use Spatie\LaravelData\Optional;

class ClaudeCompletionResponse extends CompletionResponse
{
public function __construct(
#[WithCastable(ClaudeContentCaster::class)]
public mixed $content,
public string $stop_reason = 'end_turn',
public ?string $tool_used = null,
/** @var array<ToolDto> */
#[WithCastable(ClaudeToolCaster::class)]
#[MapInputName('content')]
public array|Optional $tool_calls,
#[MapInputName('usage.input_tokens')]
public ?int $input_tokens = null,
#[MapInputName('usage.output_tokens')]
public ?int $output_tokens = null,
public ?string $model = null,
) {
}
}
33 changes: 33 additions & 0 deletions Modules/LlmDriver/app/Responses/ClaudeContentCaster.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Pgvector\Laravel\Vector;
use Spatie\LaravelData\Casts\Cast;
use Spatie\LaravelData\Casts\Castable;
use Spatie\LaravelData\Support\Creation\CreationContext;
use Spatie\LaravelData\Support\DataProperty;

class ClaudeContentCaster implements Castable
{
public function __construct(public array $content)
{

}

public static function dataCastUsing(...$arguments): Cast
{
return new class implements Cast
{
public function cast(DataProperty $property, mixed $value, array $properties, CreationContext $context): mixed
{
$results = collect($value)->filter(
function ($item) {
return $item['type'] === 'text';
}
)->first();
return data_get($results, 'text');
}
};
}
}
43 changes: 43 additions & 0 deletions Modules/LlmDriver/app/Responses/ClaudeToolCaster.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<?php

namespace LlmLaraHub\LlmDriver\Responses;

use Pgvector\Laravel\Vector;
use Spatie\LaravelData\Casts\Cast;
use Spatie\LaravelData\Casts\Castable;
use Spatie\LaravelData\Support\Creation\CreationContext;
use Spatie\LaravelData\Support\DataProperty;

class ClaudeToolCaster implements Castable
{
public function __construct(public array $tools)
{

}

public static function dataCastUsing(...$arguments): Cast
{
return new class implements Cast
{
public function cast(DataProperty $property, mixed $value, array $properties, CreationContext $context): mixed
{
$results = collect($value)->filter(
function ($item) {
return $item['type'] === 'tool_use';
}
)->toArray();

foreach($results as $index => $result) {
$results[$index] = ToolDto::from(
[
'name' => $result['name'],
'arguments' => $result['input'],
'id' => $result['id'],
]
);
}
return $results;
}
};
}
}
8 changes: 6 additions & 2 deletions Modules/LlmDriver/app/Responses/CompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

namespace LlmLaraHub\LlmDriver\Responses;


use Spatie\LaravelData\Attributes\MapInputName;
use Spatie\LaravelData\Optional;

class CompletionResponse extends \Spatie\LaravelData\Data
{
public function __construct(
public string $content,
public mixed $content,
public string $stop_reason = 'end_turn',
public ?string $tool_used = null,
/** @var array<ToolDto> */
public array $tool_calls = [],
public array|Optional $tool_calls,
public ?int $input_tokens = null,
public ?int $output_tokens = null,
public ?string $model = null,
Expand Down
5 changes: 3 additions & 2 deletions Modules/LlmDriver/app/Responses/OllamaCompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@

namespace LlmLaraHub\LlmDriver\Responses;

use Spatie\LaravelData\Optional;
use Spatie\LaravelData\Attributes\MapInputName;

class OllamaCompletionResponse extends CompletionResponse
{
public function __construct(
#[MapInputName('message.content')]
public string $content,
public mixed $content,
#[MapInputName('done_reason')]
public string $stop_reason = 'end_turn',
public ?string $tool_used = null,
/** @var array<OllamaToolDto> */
#[MapInputName('message.tool_calls')]
public array $tool_calls = [],
public array|Optional $tool_calls,
#[MapInputName('prompt_eval_count')]
public ?int $input_tokens = null,
#[MapInputName('eval_count')]
Expand Down
1 change: 1 addition & 0 deletions Modules/LlmDriver/app/Responses/ToolDto.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ToolDto extends Data
public function __construct(
public string $name,
public array $arguments,
public string $id = "",
) {
}

Expand Down
31 changes: 31 additions & 0 deletions tests/Feature/ClaudeCompletionResponseTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<?php

namespace Feature;

use Illuminate\Foundation\Testing\RefreshDatabase;
use Illuminate\Foundation\Testing\WithFaker;
use Illuminate\Support\Arr;
use LlmLaraHub\LlmDriver\Responses\ClaudeCompletionResponse;
use LlmLaraHub\LlmDriver\Responses\ClaudeToolDto;
use LlmLaraHub\LlmDriver\Responses\OllamaCompletionResponse;
use LlmLaraHub\LlmDriver\Responses\OllamaToolDto;
use LlmLaraHub\LlmDriver\Responses\ToolDto;
use Tests\TestCase;

class ClaudeCompletionResponseTest extends TestCase
{
/**
* A basic feature test example.
*/
public function test_dto(): void
{
$results = get_fixture('cloud_client_tool_use_response.json');
$dto = ClaudeCompletionResponse::from($results);
$this->assertNotNull($dto->stop_reason);
$this->assertNotNull($dto->model);
$this->assertNotNull($dto->content);
$this->assertNotNull($dto->tool_calls);
$tool = Arr::first($dto->tool_calls);
$this->assertInstanceOf(ToolDto::class, $tool);
}
}

0 comments on commit bdb7c52

Please sign in to comment.