diff --git a/controllers/client.rb b/controllers/client.rb index d43fe13..11a99e0 100644 --- a/controllers/client.rb +++ b/controllers/client.rb @@ -17,6 +17,8 @@ class Client DEFAULT_SERVICE_VERSION = 'v1' + attr_reader :base_address + def initialize(config) @service = config[:credentials][:service] @@ -54,12 +56,12 @@ def initialize(config) @service_version = config.dig(:credentials, :version) || DEFAULT_SERVICE_VERSION - @base_address = case @service - when 'vertex-ai-api' - "https://#{config[:credentials][:region]}-aiplatform.googleapis.com/#{@service_version}/projects/#{@project_id}/locations/#{config[:credentials][:region]}" - when 'generative-language-api' - "https://generativelanguage.googleapis.com/#{@service_version}" - end + @base_address = config[:credentials][:base_address] || case @service + when 'vertex-ai-api' + "https://#{config[:credentials][:region]}-aiplatform.googleapis.com/#{@service_version}/projects/#{@project_id}/locations/#{config[:credentials][:region]}" + when 'generative-language-api' + "https://generativelanguage.googleapis.com/#{@service_version}" + end @model_address = case @service when 'vertex-ai-api' @@ -81,6 +83,8 @@ def initialize(config) else {} end + + @custom_headers = config.dig(:options, :headers) || {} end def avoid_conflicting_credentials!(credentials) @@ -179,7 +183,7 @@ def request(path, payload, server_sent_events: nil, request_method: 'POST', &cal method_to_call = request_method.to_s.strip.downcase.to_sym response = Faraday.new(request: @request_options) do |faraday| - faraday.adapter @faraday_adapter + faraday.adapter(*@faraday_adapter) faraday.response :raise_error end.send(method_to_call) do |request| request.url url @@ -187,6 +191,7 @@ def request(path, payload, server_sent_events: nil, request_method: 'POST', &cal if @authentication == :service_account || @authentication == :default_credentials request.headers['Authorization'] = "Bearer #{@authorizer.fetch_access_token!['access_token']}" end + @custom_headers.each { |key, value| request.headers[key] = value } request.body = payload.to_json unless payload.nil? diff --git a/spec/controllers/client_spec.rb b/spec/controllers/client_spec.rb index 6c7df9c..6cf03aa 100644 --- a/spec/controllers/client_spec.rb +++ b/spec/controllers/client_spec.rb @@ -45,4 +45,64 @@ "You must choose either 'file_contents', or 'file_path'." ) end + + describe 'custom base address' do + it 'uses the custom base address when provided' do + custom_base_address = 'https://custom-gemini-api.example.com/v1' + client = described_class.new( + credentials: { + service: 'vertex-ai-api', + region: 'us-east4', + base_address: custom_base_address, + api_key: 'key' + }, + options: { model: 'gemini-pro' } + ) + + expect(client.base_address).to eq(custom_base_address) + end + + it 'uses the default base address when not provided' do + client = described_class.new( + credentials: { + service: 'vertex-ai-api', + region: 'us-east4', + api_key: 'key' + }, + options: { model: 'gemini-pro' } + ) + + expect(client.base_address).to include('aiplatform.googleapis.com') + end + end + + describe 'custom headers' do + let(:stubs) { Faraday::Adapter::Test::Stubs.new } + let(:faraday_test_adapter) { :test } + + it 'sends custom headers with the request' do + custom_headers = { 'X-Custom-Header' => 'CustomValue' } + + stubs.post(/.*/) do |env| + expect(env.request_headers).to include(custom_headers) + [200, {}, '{}'] + end + + client = described_class.new( + credentials: { + service: 'vertex-ai-api', + region: 'us-east4', + api_key: 'key' + }, + options: { + model: 'gemini-pro', + headers: custom_headers, + connection: { adapter: [:test, stubs] } + } + ) + + client.predict({ content: 'Test' }) + stubs.verify_stubbed_calls + end + end end