Skip to content

Commit

Permalink
OpenAi setup now but not 100% sure this is still worth it let me see …
Browse files Browse the repository at this point in the history
…what Ollama shows
  • Loading branch information
alnutile committed Jul 14, 2024
1 parent a74a216 commit 52f2c92
Show file tree
Hide file tree
Showing 13 changed files with 573 additions and 39 deletions.
5 changes: 5 additions & 0 deletions Modules/LlmDriver/app/BaseClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ public function getFunctions(): array
{
$functions = LlmDriverFacade::getFunctions();

return $this->remapFunctions($functions);
}

public function remapFunctions(array $functions): array
{
return collect($functions)->map(function ($function) {
$function = $function->toArray();
$properties = [];
Expand Down
191 changes: 163 additions & 28 deletions Modules/LlmDriver/app/OpenAiClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

use App\Models\Setting;
use Illuminate\Http\Client\Pool;
use Illuminate\Http\Client\Response;
use Illuminate\Support\Facades\Http;
use Illuminate\Support\Facades\Log;
use Laravel\Pennant\Feature;
use LlmLaraHub\LlmDriver\Functions\FunctionDto;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;
Expand All @@ -22,19 +25,52 @@ class OpenAiClient extends BaseClient
*/
public function chat(array $messages): CompletionResponse
{
$token = Setting::getSecret('openai', 'api_key');

$response = OpenAI::chat()->create([
if (is_null($token)) {
throw new \Exception('Missing open ai api key');
}

$payload = [
'model' => $this->getConfig('openai')['models']['chat_model'],
'messages' => $this->messagesToArray($messages),
]);
];

$results = null;
$payload = $this->modifyPayload($payload);

foreach ($response->choices as $result) {
$results = $result->message->content;
$response = Http::withHeaders([
'Content-type' => 'application/json',
])
->withToken($token)
->baseUrl($this->baseUrl)
->timeout(240)
->retry(3, function (int $attempt, \Exception $exception) {
Log::info('OpenAi API Error going to retry', [
'attempt' => $attempt,
'error' => $exception->getMessage(),
]);

return 60000;
})
->post('/chat/completions', $payload);

if ($response->failed()) {
Log::error('OpenAi API Error ', [
'error' => $response->body(),
]);

throw new \Exception('OpenAi API Error Chat');
}

return new CompletionResponse($results);
[$data, $tool_used, $stop_reason] = $this->getContentAndToolTypeFromResults($response);

return CompletionResponse::from([
'content' => $data,
'tool_used' => $tool_used,
'stop_reason' => $stop_reason,
'input_tokens' => data_get($response, 'usage.prompt_tokens', null),
'output_tokens' => data_get($response, 'usage.completion_tokens', null),
]);
}

public function embedData(string $data): EmbeddingsResponseDto
Expand Down Expand Up @@ -72,6 +108,15 @@ public function completionPool(array $prompts, int $temperature = 0): array

$responses = Http::pool(function (Pool $pool) use ($prompts, $token) {
foreach ($prompts as $prompt) {
$payload = [
'model' => $this->getConfig('openai')['models']['completion_model'],
'messages' => [
['role' => 'user', 'content' => $prompt],
],
];

$payload = $this->modifyPayload($payload);

$pool->withHeaders([
'content-type' => 'application/json',
'Authorization' => 'Bearer '.$token,
Expand All @@ -86,12 +131,7 @@ public function completionPool(array $prompts, int $temperature = 0): array

return 60000;
})
->post('/chat/completions', [
'model' => $this->getConfig('openai')['models']['completion_model'],
'messages' => [
['role' => 'user', 'content' => $prompt],
],
]);
->post('/chat/completions', $payload);
}

});
Expand All @@ -100,13 +140,14 @@ public function completionPool(array $prompts, int $temperature = 0): array

foreach ($responses as $index => $response) {
if ($response->ok()) {
$response = $response->json();
foreach (data_get($response, 'choices', []) as $result) {
$result = data_get($result, 'message.content', '');
$results[] = CompletionResponse::from([
'content' => $result,
]);
}
[$data, $tool_used, $stop_reason] = $this->getContentAndToolTypeFromResults($response);
$results[] = CompletionResponse::from([
'content' => $data,
'tool_used' => $tool_used,
'stop_reason' => $stop_reason,
'input_tokens' => data_get($response, 'usage.prompt_tokens', null),
'output_tokens' => data_get($response, 'usage.completion_tokens', null),
]);
} else {
Log::error('OpenAi API Error ', [
'index' => $index,
Expand All @@ -133,6 +174,8 @@ public function completion(string $prompt, int $temperature = 0): CompletionResp
],
];

$payload = $this->modifyPayload($payload);

$response = Http::withHeaders([
'Content-type' => 'application/json',
])
Expand All @@ -149,15 +192,84 @@ public function completion(string $prompt, int $temperature = 0): CompletionResp
})
->post('/chat/completions', $payload);

$results = null;
if ($response->failed()) {
Log::error('OpenAi API Error ', [
'error' => $response->body(),
]);

$response = $response->json();
throw new \Exception('OpenAi API Error Chat');
}

[$data, $tool_used, $stop_reason] = $this->getContentAndToolTypeFromResults($response);

return CompletionResponse::from([
'content' => $data,
'tool_used' => $tool_used,
'stop_reason' => $stop_reason,
'input_tokens' => data_get($response, 'usage.prompt_tokens', null),
'output_tokens' => data_get($response, 'usage.completion_tokens', null),
]);
}

public function getContentAndToolTypeFromResults(Response $results): array
{
$results = $results->json();
$tool_used = null;
$stop_reason = data_get($results, 'choices.0.finish_reason', 'stop');
$tool_calls = data_get($results, 'choices.0.message.tool_calls', []);

if ($stop_reason === 'tool_calls' || ! empty($tool_calls)) {
/**
* @TOOD
* The tool should be used here to get the
* output since it might be different
* for each tool
* Right now it assumes the JSON one is being used
*/
foreach ($results['choices'] as $content) {
$tool_used = data_get($content, 'message.tool_calls.0.function.name');
$data = json_encode(data_get($content, 'message.tool_calls.0.function.arguments', []), JSON_THROW_ON_ERROR);
}
} else {
foreach (data_get($results, 'choices', []) as $result) {
$data = data_get($result, 'message.content', '');
}
}

return [$data, $tool_used, $stop_reason];
}

foreach (data_get($response, 'choices', []) as $result) {
$results = data_get($result, 'message.content', '');
public function modifyPayload(array $payload): array
{
Log::info('LlmDriver::OpenAi::modifyPayload', [
'payload' => $payload,
'forceTool' => $this->forceTool,
]);

if (! empty($this->forceTool)) {
$function = [$this->forceTool];
$function = $this->remapFunctions($function);

$payload['tools'] = $function;
$payload['tool_choice'] = [
'type' => 'function',
'function' => [
'name' => $this->forceTool->name,
],
];
} else {
//I should add all the tools here?
if (Feature::active('all_tools')) {
$payload['tools'] = $this->getFunctions();
$payload['tool_choice'] = 'auto';
} else {
//$payload['tool_choice'] = 'none';
}
}

return new CompletionResponse($results);
$payload = $this->addJsonFormat($payload);

return $payload;
}

/**
Expand Down Expand Up @@ -215,11 +327,22 @@ public function getFunctions(): array
{
$functions = LlmDriverFacade::getFunctions();

return $this->remapFunctions($functions);

}

/**
* @param FunctionDto[] $functions
*/
public function remapFunctions(array $functions): array
{
return collect($functions)->map(function ($function) {
$function = $function->toArray();
$properties = [];
$required = [];

$type = data_get($function, 'parameters.type', 'object');

foreach (data_get($function, 'parameters.properties', []) as $property) {
$name = data_get($property, 'name');

Expand All @@ -230,8 +353,21 @@ public function getFunctions(): array
$properties[$name] = [
'description' => data_get($property, 'description', null),
'type' => data_get($property, 'type', 'string'),
'enum' => data_get($property, 'enum', []),
'default' => data_get($property, 'default', null),
];
}

$itemsOrProperties = $properties;

if ($type === 'array') {
$itemsOrProperties = [
'results' => [
'type' => 'array',
'description' => 'The results of prompt',
'items' => [
'type' => 'object',
'properties' => $properties,
],
],
];
}

Expand All @@ -242,9 +378,8 @@ public function getFunctions(): array
'description' => data_get($function, 'description'),
'parameters' => [
'type' => 'object',
'properties' => $properties,
'properties' => $itemsOrProperties,
],
'required' => $required,
],
];
})->toArray();
Expand Down
1 change: 0 additions & 1 deletion Modules/LlmDriver/tests/Feature/ClaudeClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ public function test_remap_array(): void

$results = (new ClaudeClient)->remapFunctions([$dto]);


$this->assertEquals(
$shouldBe,
$results
Expand Down
50 changes: 41 additions & 9 deletions Modules/LlmDriver/tests/Feature/OpenAiClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

use App\Models\Setting;
use Illuminate\Support\Facades\Http;
use LlmLaraHub\LlmDriver\Functions\FunctionDto;
use LlmLaraHub\LlmDriver\Functions\ParametersDto;
use LlmLaraHub\LlmDriver\Functions\PropertyDto;
use LlmLaraHub\LlmDriver\Requests\MessageInDto;
use LlmLaraHub\LlmDriver\Responses\CompletionResponse;
use LlmLaraHub\LlmDriver\Responses\EmbeddingsResponseDto;
Expand Down Expand Up @@ -70,20 +73,52 @@ public function test_completion(): void
$this->assertInstanceOf(CompletionResponse::class, $response);
}

public function test_remap_array(): void
{
$dto = FunctionDto::from([
'name' => 'reporting_json',
'description' => 'JSON Summary of the report',
'parameters' => ParametersDto::from([
'type' => 'array',
'properties' => [
PropertyDto::from([
'name' => 'title',
'description' => 'The title of the section',
'type' => 'string',
'required' => true,
]),
PropertyDto::from([
'name' => 'content',
'description' => 'The content of the section',
'type' => 'string',
'required' => true,
]),
],
]),
]);

$openaiClient = new \LlmLaraHub\LlmDriver\OpenAiClient();
$response = $openaiClient->remapFunctions([$dto]);
$shouldBe = get_fixture('openai_payload_modified.json');
$shouldBe = data_get($shouldBe, 'tools', []);
$this->assertEquals($shouldBe, $response);

}

public function test_chat(): void
{
OpenAI::fake([
ChatCreateResponse::fake([
Http::fake([
'api.openai.com/*' => Http::response([
'choices' => [
[
'message' => [
'content' => 'awesome!',
],
'messages' => [
'content' => 'Foo bar',
],
],
]),
]);

Http::preventStrayRequests();

$openaiClient = new \LlmLaraHub\LlmDriver\OpenAiClient();
$response = $openaiClient->chat([
MessageInDto::from([
Expand All @@ -107,9 +142,6 @@ public function test_functions_prompt(): void
'choices' => data_get($data, 'choices', []),
];

// OpenAI::fake([
// ChatCreateResponse::fake($response)
// ]);
OpenAI::fake([
ChatCreateResponse::fake([
'choices' => [
Expand Down
8 changes: 8 additions & 0 deletions app/Providers/AppServiceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ public function boot(): void
return config('llmdriver.features.reference_collection'); //just not ready yet
});

Feature::define('all_tools', function (User $user) {
if (config('llmdriver.features.all_tools')) {
return true;
}

return false;
});

Feature::define('verification_prompt_tags', function (User $user) {
return false;
});
Expand Down
Loading

0 comments on commit 52f2c92

Please sign in to comment.