From 7724fc8d39df3a0cdd62c1ff6b2de3c4381c5839 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Sat, 27 Apr 2024 23:38:10 +0100 Subject: [PATCH] feat(ai-proxy): google-gemini support --- .../kong/ai-proxy-google-gemini.yml | 5 + kong/llm/drivers/gemini.lua | 312 ++++++++++++++++++ kong/llm/drivers/shared.lua | 10 + kong/llm/init.lua | 2 +- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 14 + .../expected-requests/gemini/llm-v1-chat.json | 57 ++++ .../gemini/llm-v1-chat.json | 14 + .../real-responses/gemini/llm-v1-chat.json | 34 ++ 8 files changed, 447 insertions(+), 1 deletion(-) create mode 100644 changelog/unreleased/kong/ai-proxy-google-gemini.yml create mode 100644 kong/llm/drivers/gemini.lua create mode 100644 spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json create mode 100644 spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json create mode 100644 spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json diff --git a/changelog/unreleased/kong/ai-proxy-google-gemini.yml b/changelog/unreleased/kong/ai-proxy-google-gemini.yml new file mode 100644 index 000000000000..bc4fb06b21c4 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-google-gemini.yml @@ -0,0 +1,5 @@ +message: | + Kong AI Gateway (AI Proxy and associated plugin family) now supports + the Google Gemini "chat" (generateContent) interface. +type: feature +scope: Plugin diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua new file mode 100644 index 000000000000..5360690b6c63 --- /dev/null +++ b/kong/llm/drivers/gemini.lua @@ -0,0 +1,312 @@ +local _M = {} + +-- imports +local cjson = require("cjson.safe") +local fmt = string.format +local ai_shared = require("kong.llm.drivers.shared") +local socket_url = require("socket.url") +local string_gsub = string.gsub +local buffer = require("string.buffer") +local table_insert = table.insert +local string_lower = string.lower +-- + +-- globals +local DRIVER_NAME = "gemini" +-- + +local _OPENAI_ROLE_MAPPING = { + ["system"] = "system", + ["user"] = "user", + ["assistant"] = "model", +} + +local function to_bard_generation_config(request_table) + return { + ["maxOutputTokens"] = request_table.max_tokens, + ["stopSequences"] = request_table.stop, + ["temperature"] = request_table.temperature, + ["topK"] = request_table.top_k, + ["topP"] = request_table.top_p, + } +end + +local function to_bard_chat_openai(request_table, model_info, route_type) + if request_table then -- try-catch type mechanism + local new_r = {} + + if request_table.messages and #request_table.messages > 0 then + local system_prompt + + for i, v in ipairs(request_table.messages) do + + -- for 'system', we just concat them all into one Gemini instruction + if v.role and v.role == "system" then + system_prompt = system_prompt or buffer.new() + system_prompt:put(v.content or "") + else + -- for any other role, just construct the chat history as 'parts.text' type + new_r.contents = new_r.contents or {} + table_insert(new_r.contents, { + role = _OPENAI_ROLE_MAPPING[v.role or "user"], -- default to 'user' + parts = { + { + text = v.content or "" + }, + }, + }) + end + end + + ---- TODO for some reason this is broken? + ---- I think it's something to do with which "regional" endpoint of Gemini you hit... + -- if system_prompt then + -- new_r.systemInstruction = { + -- parts = { + -- { + -- text = system_prompt:get(), + -- }, + -- }, + -- } + -- end + ---- + + end + + new_r.generationConfig = to_bard_generation_config(request_table) + + kong.log.debug(cjson.encode(new_r)) + + return new_r, "application/json", nil + end + + local err = "empty request table received for transformation" + ngx.log(ngx.ERR, err) + return nil, nil, err +end + +local function from_bard_chat_openai(response, model_info, route_type) + local response, err = cjson.decode(response) + + if err then + local err_client = "failed to decode response from Gemini" + ngx.log(ngx.ERR, fmt("%s: %s", err_client, err)) + return nil, err_client + end + + -- messages/choices table is only 1 size, so don't need to static allocate + local messages = {} + messages.choices = {} + + if response.candidates + and #response.candidates > 0 + and response.candidates[1].content + and response.candidates[1].content.parts + and #response.candidates[1].content.parts > 0 + and response.candidates[1].content.parts[1].text then + + messages.choices[1] = { + index = 0, + message = { + role = "assistant", + content = response.candidates[1].content.parts[1].text, + }, + finish_reason = string_lower(response.candidates[1].finishReason), + } + messages.object = "chat.completion" + messages.model = model_info.name + + else -- probably a server fault or other unexpected response + local err = "no generation candidates received from Gemini, or max_tokens too short" + ngx.log(ngx.ERR, err) + return nil, err + end + + return cjson.encode(messages) +end + +local function to_bard_chat_bard(request_table, model_info, route_type) + return nil, nil, "bard to bard not yet implemented" +end + +local function from_bard_chat_bard(request_table, model_info, route_type) + return nil, nil, "bard to bard not yet implemented" +end + +local transformers_to = { + ["llm/v1/chat"] = to_bard_chat_openai, + ["gemini/v1/chat"] = to_gemini_chat_bard, +} + +local transformers_from = { + ["llm/v1/chat"] = from_bard_chat_openai, + ["gemini/v1/chat"] = from_gemini_chat_bard, +} + +function _M.from_format(response_string, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from ", model_info.provider, "://", route_type, " type to kong") + + -- MUST return a string, to set as the response body + if not transformers_from[route_type] then + return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) + end + + local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info, route_type) + if not ok or err then + return nil, fmt("transformation failed from type %s://%s: %s", + model_info.provider, + route_type, + err or "unexpected_error" + ) + end + + return response_string, nil +end + +function _M.to_format(request_table, model_info, route_type) + ngx.log(ngx.DEBUG, "converting from kong type to ", model_info.provider, "/", route_type) + + if route_type == "preserve" then + -- do nothing + return request_table, nil, nil + end + + if not transformers_to[route_type] then + return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type) + end + + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + + local ok, response_object, content_type, err = pcall( + transformers_to[route_type], + request_table, + model_info + ) + if err or (not ok) then + return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type) + end + + return response_object, content_type, nil +end + +function _M.subrequest(body, conf, http_opts, return_res_table) + -- use shared/standard subrequest routine + local body_string, err + + if type(body) == "table" then + body_string, err = cjson.encode(body) + if err then + return nil, nil, "failed to parse body to json: " .. err + end + elseif type(body) == "string" then + body_string = body + else + return nil, nil, "body must be table or string" + end + + -- may be overridden + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( + "%s%s", + ai_shared.upstream_url_format[DRIVER_NAME], + ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + ) + + local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method + + local headers = { + ["Accept"] = "application/json", + ["Content-Type"] = "application/json", + } + + if conf.auth and conf.auth.header_name then + headers[conf.auth.header_name] = conf.auth.header_value + end + + local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table) + if err then + return nil, nil, "request to ai service failed: " .. err + end + + if return_res_table then + return res, res.status, nil, httpc + else + -- At this point, the entire request / response is complete and the connection + -- will be closed or back on the connection pool. + local status = res.status + local body = res.body + + if status > 299 then + return body, res.status, "status code " .. status + end + + return body, res.status, nil + end +end + +function _M.header_filter_hooks(body) + -- nothing to parse in header_filter phase +end + +function _M.post_request(conf) + if ai_shared.clear_response_headers[DRIVER_NAME] then + for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do + kong.response.clear_header(v) + end + end +end + +function _M.pre_request(conf, body) + kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli + + return true, nil +end + +-- returns err or nil +function _M.configure_request(conf) + local parsed_url + + if (conf.model.options and conf.model.options.upstream_url) then + parsed_url = socket_url.parse(conf.model.options.upstream_url) + else + local path = conf.model.options + and conf.model.options.upstream_path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type] + and fmt(ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, conf.model.name) + or "/" + if not path then + return nil, fmt("operation %s is not supported for openai provider", conf.route_type) + end + + parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) + parsed_url.path = path + end + + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") + + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + + local auth_header_name = conf.auth and conf.auth.header_name + local auth_header_value = conf.auth and conf.auth.header_value + local auth_param_name = conf.auth and conf.auth.param_name + local auth_param_value = conf.auth and conf.auth.param_value + local auth_param_location = conf.auth and conf.auth.param_location + + if auth_header_name and auth_header_value then + kong.service.request.set_header(auth_header_name, auth_header_value) + end + + if auth_param_name and auth_param_value and auth_param_location == "query" then + local query_table = kong.request.get_query() + query_table[auth_param_name] = auth_param_value + kong.service.request.set_query(query_table) + end + + -- if auth_param_location is "form", it will have already been set in a global pre-request hook + return true, nil +end + +return _M diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index f4696f116c10..228aed861ce5 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -55,6 +55,7 @@ _M.upstream_url_format = { anthropic = "https://api.anthropic.com:443", cohere = "https://api.cohere.com:443", azure = "https://%s.openai.azure.com:443/openai/deployments/%s", + gemini = "https://generativelanguage.googleapis.com", } _M.operation_map = { @@ -98,6 +99,12 @@ _M.operation_map = { method = "POST", }, }, + gemini = { + ["llm/v1/chat"] = { + path = "/v1/models/%s:generateContent", -- /v1/models/gemini-pro:generateContent, + method = "POST", + }, + }, } _M.clear_response_headers = { @@ -113,6 +120,9 @@ _M.clear_response_headers = { mistral = { "Set-Cookie", }, + gemini = { + "Set-Cookie", + }, } --- diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 6ac1a1ff0b96..0a67f232d19b 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -120,7 +120,7 @@ local model_schema = { type = "string", description = "AI provider request format - Kong translates " .. "requests to and from the specified backend compatible formats.", required = true, - one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2" }}}, + one_of = { "openai", "azure", "anthropic", "cohere", "mistral", "llama2", "gemini" }}}, { name = { type = "string", description = "Model name to execute.", diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 9ff754a1407f..44277c8e4a0b 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -223,6 +223,20 @@ local FORMATS = { }, }, }, + gemini = { + ["llm/v1/chat"] = { + config = { + name = "gemini-pro", + provider = "gemini", + options = { + max_tokens = 8192, + temperature = 0.8, + top_k = 1, + top_p = 0.6, + }, + }, + }, + }, } local STREAMS = { diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..f236df678a4d --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-requests/gemini/llm-v1-chat.json @@ -0,0 +1,57 @@ +{ + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "What is 1 + 2?" + } + ] + }, + { + "role": "model", + "parts": [ + { + "text": "The sum of 1 + 2 is 3. If you have any more math questions or if there's anything else I can help you with, feel free to ask!" + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "Multiply that by 2" + } + ] + }, + { + "role": "model", + "parts": [ + { + "text": "Certainly! If you multiply 3 by 2, the result is 6. If you have any more questions or if there's anything else I can help you with, feel free to ask!" + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "Why can't you divide by zero?" + } + ] + } + ], + "generationConfig": { + "temperature": 0.8, + "topK": 1, + "topP": 0.6, + "maxOutputTokens": 8192 + }, + "systemInstruction": { + "parts": [ + { + "text": "You are a mathematician." + } + ] + } +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..90a1656d2a37 --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/expected-responses/gemini/llm-v1-chat.json @@ -0,0 +1,14 @@ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n", + "role": "assistant" + } + } + ], + "model": "gemini-pro", + "object": "chat.completion" +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json new file mode 100644 index 000000000000..80781b6eb72a --- /dev/null +++ b/spec/fixtures/ai-proxy/unit/real-responses/gemini/llm-v1-chat.json @@ -0,0 +1,34 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Ah, vous voulez savoir le double de ce résultat ? Eh bien, le double de 2 est **4**. \n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ] + } \ No newline at end of file