Skip to content

Commit

Permalink
add function helper
Browse files Browse the repository at this point in the history
  • Loading branch information
alnutile committed Apr 2, 2024
1 parent 351f6ef commit abfdfcb
Show file tree
Hide file tree
Showing 14 changed files with 174 additions and 55 deletions.
40 changes: 38 additions & 2 deletions app/LlmDriver/BaseClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,42 @@ public function embedData(string $data): EmbeddingsResponseDto
]);
}

/**
* This is to get functions out of the llm
* if none are returned your system
* can error out or try another way.
*
* @param MessageInDto[] $messages
*/
public function functionPromptChat(array $messages, array $only = []): array
{
if (! app()->environment('testing')) {
sleep(2);
}

Log::info('LlmDriver::MockClient::functionPromptChat', $messages);

$data = get_fixture('openai_response_with_functions.json');

$functions = [];

foreach (data_get($data, 'choices', []) as $choice) {
foreach (data_get($choice, 'message.toolCalls', []) as $tool) {
if (data_get($tool, 'type') === 'function') {
$name = data_get($tool, 'function.name', null);
if (! in_array($name, $only)) {
$functions[] = [
'name' => $name,
'arguments' => json_decode(data_get($tool, 'function.arguments', []), true),
];
}
}
}
}

return $functions;
}

/**
* @param MessageInDto[] $messages
*/
Expand Down Expand Up @@ -64,8 +100,8 @@ protected function getConfig(string $driver): array
return config("llmdriver.drivers.$driver");
}


public function getFunctions() : array {
public function getFunctions(): array
{
return [];
}
}
20 changes: 9 additions & 11 deletions app/LlmDriver/Functions/FunctionContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@

namespace App\LlmDriver\Functions;

use App\LlmDriver\Functions\PropertyDto;

abstract class FunctionContract
{
protected string $name;

protected string $dscription;
protected string $description;

protected string $type = 'object';

/**
* @param array<string, mixed> $data
* @return array<string, mixed>
*/
abstract public function handle(FunctionCallDto $functionCallDto): array;

public function getFunction(): FunctionDto
Expand All @@ -35,13 +33,13 @@ protected function getName(): string
return $this->name;
}

/**
* @return ParameterDto[]
*/
abstract protected function getProperties(): array;

protected function getDescription(): string
{
return $this->name;
return $this->description;
}

/**
* @return PropertyDto[]
*/
abstract protected function getProperties(): array;
}
6 changes: 3 additions & 3 deletions app/LlmDriver/Functions/SearchAndSummarize.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ class SearchAndSummarize extends FunctionContract
{
protected string $name = 'search_and_summarize';

protected string $dscription = 'Used to embed users prompt, search database and return summarized results.';
protected string $description = 'Used to embed users prompt, search database and return summarized results.';

public function handle(FunctionCallDto $functionCallDto): array
{
return [];
}

/**
* @return ParameterDto[]
* @return PropertyDto[]
*/
protected function getProperties(): array
{
return [
new PropertyDto(
name: 'prompt',
description: 'The prompt to search for in the database.',
description: 'The prompt the user is using the search for.',
type: 'string',
required: true,
),
Expand Down
4 changes: 2 additions & 2 deletions app/LlmDriver/LlmDriverClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ protected function getDefaultDriver()
return 'mock';
}

public function getFunctions() : array {
public function getFunctions(): array
{
return [
(new SearchAndSummarize())->getFunction(),
];
}

}
1 change: 0 additions & 1 deletion app/LlmDriver/LlmServiceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

namespace App\LlmDriver;

use App\LlmDriver\LlmDriverClient;
use Illuminate\Support\ServiceProvider;

class LlmServiceProvider extends ServiceProvider
Expand Down
42 changes: 37 additions & 5 deletions app/LlmDriver/OpenAiClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ public function chat(array $messages): CompletionResponse
{
$functions = $this->getFunctions();


$response = OpenAI::chat()->create([
'model' => $this->getConfig('openai')['completion_model'],
'model' => $this->getConfig('openai')['chat_model'],
'messages' => collect($messages)->map(function ($message) {
return $message->toArray();
})->toArray(),
'tool_choice' => 'auto',
'tools' => $functions,
]);

$results = null;
Expand Down Expand Up @@ -78,12 +79,43 @@ public function completion(string $prompt, int $temperature = 0): CompletionResp
* Since this abstraction layer is based on OpenAi
* Not much needs to happen here
* but on the others I might need to do XML?
* @return array
*/
public function getFunctions() : array {
public function getFunctions(): array
{
$functions = LlmDriverFacade::getFunctions();

return collect($functions)->map(function ($function) {
return $function->toArray();
$function = $function->toArray();
$properties = [];
$required = [];

foreach (data_get($function, 'parameters.properties', []) as $property) {
$name = data_get($property, 'name');

if (data_get($property, 'required', false)) {
$required[] = $name;
}

$properties[$name] = [
'description' => data_get($property, 'description', null),
'type' => data_get($property, 'type', 'string'),
'enum' => data_get($property, 'enum', []),
'default' => data_get($property, 'default', null),
];
}

return [
'type' => 'function',
'function' => [
'name' => data_get($function, 'name'),
'description' => data_get($function, 'description'),
'parameters' => [
'type' => 'object',
'properties' => $properties,
],
'required' => $required,
],
];
})->toArray();
}
}
2 changes: 0 additions & 2 deletions tests/Feature/FunctionDtoTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,4 @@ public function test_dto(): void
$this->assertEquals('bar', $parameterTwo->default);
$this->assertTrue($parameterTwo->required);
}


}
3 changes: 2 additions & 1 deletion tests/Feature/LlmDriverClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public function test_driver_openai(): void
$this->assertInstanceOf(OpenAiClient::class, $results);
}

public function test_get_functions() {
public function test_get_functions()
{
$this->assertNotEmpty(LlmDriverFacade::getFunctions());
}
}
24 changes: 21 additions & 3 deletions tests/Feature/MockClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,27 @@

class MockClientTest extends TestCase
{
/**
* A basic feature test example.
*/
public function test_tools(): void
{

$client = new MockClient();

$results = $client->functionPromptChat(['test']);

$this->assertCount(1, $results);

$this->assertEquals('search_and_summarize', $results[0]['name']);
}

public function test_tool_with_limit(): void
{
$client = new MockClient();

$results = $client->functionPromptChat(['test'], ['search_and_summarize']);

$this->assertCount(0, $results);
}

public function test_embeddings(): void
{

Expand Down
8 changes: 4 additions & 4 deletions tests/Feature/OpenAiClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@

class OpenAiClientTest extends TestCase
{

public function test_get_functions(): void
{
$openaiClient = new \App\LlmDriver\OpenAiClient();
$response = $openaiClient->getFunctions();
$this->assertNotEmpty($response);
$this->assertIsArray($response);
$first = $response[0];
$this->assertArrayHasKey('name', $first);
$this->assertArrayHasKey('description', $first);
$this->assertArrayHasKey('parameters', $first);
$this->assertArrayHasKey('type', $first);
$this->assertArrayHasKey('function', $first);
$expected = get_fixture('openai_client_get_functions.json');

$this->assertEquals($expected, $response);

}

Expand Down
3 changes: 0 additions & 3 deletions tests/Feature/SearchAndSummarizeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

namespace Tests\Feature;

use App\LlmDriver\Functions\ParameterDto;
use App\LlmDriver\Functions\ParametersDto;
use App\LlmDriver\Functions\PropertyDto;
use Illuminate\Foundation\Testing\RefreshDatabase;
use Illuminate\Foundation\Testing\WithFaker;
use Tests\TestCase;

class SearchAndSummarizeTest extends TestCase
Expand Down
20 changes: 2 additions & 18 deletions tests/fixtures/claude_messages_debug.json
Original file line number Diff line number Diff line change
@@ -1,26 +1,10 @@
[
{
"content": "test",
"content": "how does the document define 'Generative AI'",
"role": "user"
},
{
"content": "test 1",
"role": "assistant"
},
{
"role": "user",
"content": "Continuation of search results"
},
{
"content": "test 2",
"role": "assistant"
},
{
"role": "user",
"content": "Continuation of search results"
},
{
"content": "test 3",
"content": "The user uploaded documents that the user will ask questions about. Please keep your answers related to the documents.",
"role": "assistant"
}
]
23 changes: 23 additions & 0 deletions tests/fixtures/openai_client_get_functions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[
{
"type": "function",
"function": {
"name": "search_and_summarize",
"description": "Used to embed users prompt, search database and return summarized results.",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"description": "The prompt the user is using the search for.",
"type": "string",
"enum": [],
"default": ""
}
}
},
"required": [
"prompt"
]
}
}
]
33 changes: 33 additions & 0 deletions tests/fixtures/openai_response_with_functions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"id": "chatcmpl-99MsQDhcyADcsehLfCQAwuyhhTxmB",
"object": "chat.completion",
"created": 1712019918,
"model": "gpt-4-0125-preview",
"systemFingerprint": "fp_f38f4d6482",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": null,
"toolCalls": [
{
"id": "call_u3GOeiE4LaSJvOqV2uOqeXK2",
"type": "function",
"function": {
"name": "search_and_summarize",
"arguments": "{\"prompt\":\"define 'Generative AI'\"}"
}
}
],
"functionCall": null
},
"finishReason": "tool_calls"
}
],
"usage": {
"promptTokens": 101,
"completionTokens": 22,
"totalTokens": 123
}
}

0 comments on commit abfdfcb

Please sign in to comment.