Skip to content

Commit

Permalink
Add custom headers and base address customization
Browse files Browse the repository at this point in the history
  • Loading branch information
timabdulla committed Sep 13, 2024
1 parent d0d2602 commit 9be18cd
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 7 deletions.
19 changes: 12 additions & 7 deletions controllers/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class Client

DEFAULT_SERVICE_VERSION = 'v1'

attr_reader :base_address

def initialize(config)
@service = config[:credentials][:service]

Expand Down Expand Up @@ -54,12 +56,12 @@ def initialize(config)

@service_version = config.dig(:credentials, :version) || DEFAULT_SERVICE_VERSION

@base_address = case @service
when 'vertex-ai-api'
"https://#{config[:credentials][:region]}-aiplatform.googleapis.com/#{@service_version}/projects/#{@project_id}/locations/#{config[:credentials][:region]}"
when 'generative-language-api'
"https://generativelanguage.googleapis.com/#{@service_version}"
end
@base_address = config[:credentials][:base_address] || case @service
when 'vertex-ai-api'
"https://#{config[:credentials][:region]}-aiplatform.googleapis.com/#{@service_version}/projects/#{@project_id}/locations/#{config[:credentials][:region]}"
when 'generative-language-api'
"https://generativelanguage.googleapis.com/#{@service_version}"
end

@model_address = case @service
when 'vertex-ai-api'
Expand All @@ -81,6 +83,8 @@ def initialize(config)
else
{}
end

@custom_headers = config.dig(:options, :headers) || {}
end

def avoid_conflicting_credentials!(credentials)
Expand Down Expand Up @@ -179,14 +183,15 @@ def request(path, payload, server_sent_events: nil, request_method: 'POST', &cal
method_to_call = request_method.to_s.strip.downcase.to_sym

response = Faraday.new(request: @request_options) do |faraday|
faraday.adapter @faraday_adapter
faraday.adapter(*@faraday_adapter)
faraday.response :raise_error
end.send(method_to_call) do |request|
request.url url
request.headers['Content-Type'] = 'application/json'
if @authentication == :service_account || @authentication == :default_credentials
request.headers['Authorization'] = "Bearer #{@authorizer.fetch_access_token!['access_token']}"
end
@custom_headers.each { |key, value| request.headers[key] = value }

request.body = payload.to_json unless payload.nil?

Expand Down
60 changes: 60 additions & 0 deletions spec/controllers/client_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,64 @@
"You must choose either 'file_contents', or 'file_path'."
)
end

describe 'custom base address' do
it 'uses the custom base address when provided' do
custom_base_address = 'https://custom-gemini-api.example.com/v1'
client = described_class.new(
credentials: {
service: 'vertex-ai-api',
region: 'us-east4',
base_address: custom_base_address,
api_key: 'key'
},
options: { model: 'gemini-pro' }
)

expect(client.base_address).to eq(custom_base_address)
end

it 'uses the default base address when not provided' do
client = described_class.new(
credentials: {
service: 'vertex-ai-api',
region: 'us-east4',
api_key: 'key'
},
options: { model: 'gemini-pro' }
)

expect(client.base_address).to include('aiplatform.googleapis.com')
end
end

describe 'custom headers' do
let(:stubs) { Faraday::Adapter::Test::Stubs.new }
let(:faraday_test_adapter) { :test }

it 'sends custom headers with the request' do
custom_headers = { 'X-Custom-Header' => 'CustomValue' }

stubs.post(/.*/) do |env|
expect(env.request_headers).to include(custom_headers)
[200, {}, '{}']
end

client = described_class.new(
credentials: {
service: 'vertex-ai-api',
region: 'us-east4',
api_key: 'key'
},
options: {
model: 'gemini-pro',
headers: custom_headers,
connection: { adapter: [:test, stubs] }
}
)

client.predict({ content: 'Test' })
stubs.verify_stubbed_calls
end
end
end

0 comments on commit 9be18cd

Please sign in to comment.