Skip to content

Commit

Permalink
feature: Add support for VoyageAI embeddings (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
milosgajdos committed Mar 24, 2024
1 parent 817f998 commit dfa2b14
Show file tree
Hide file tree
Showing 10 changed files with 452 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@

# Dependency directories (remove the comment below to include it)
# vendor/
.env
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ Currently supported APIs:
* [x] [OpenAI](https://platform.openai.com/docs/api-reference/embeddings)
* [x] [Cohere AI](https://docs.cohere.com/reference/embed)
* [x] [Google Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings)
* [x] [VoyageAI](https://docs.voyageai.com/reference/embeddings-api)

You can find sample programs that demonstrate how to use the client packages to fetch the embeddings in `cmd` directory of this project.

Finally, the `document` package provides an implementation of simple document text splitters, heavily inspired by the popular [Langchain framework](https://github.com/langchain-ai/langchain).
It's essentially a Go rewrite of character and recursive character text splitters.
It's essentially a Go rewrite of character and recursive character text splitters from the Langchain framework with minor modifications, but more or less identical results.

## Environment variables

Expand All @@ -33,6 +34,7 @@ Each client package lets you initialize a default API client for a specific embe
* `VERTEXAI_TOKEN`: Google Vertex AI API token (can be fetch by `gcloud auth print-access-token` once you've authenticated)
* `VERTEXAI_MODEL_ID`: Embeddings model (at the moment only `textembedding-gecko@00` or `multimodalembedding@001` are available)
* `GOOGLE_PROJECT_ID`: Google Project ID
* `VOYAGE_API_KEY`: VoyageAI API key

## nix

Expand Down
44 changes: 44 additions & 0 deletions cmd/voyage/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/voyage"
)

var (
input string
model string
truncation bool
inputType string
)

func init() {
flag.StringVar(&input, "input", "what is life", "input data")
flag.StringVar(&model, "model", voyage.VoyageV2.String(), "model name")
flag.StringVar(&inputType, "input-type", voyage.DocInput.String(), "input type")
flag.BoolVar(&truncation, "truncate", false, "truncate type")
}

func main() {
flag.Parse()

c := voyage.NewClient()

embReq := &voyage.EmbeddingRequest{
Input: []string{input},
Model: voyage.Model(model),
InputType: voyage.InputType(inputType),
Truncation: truncation,
}

embs, err := c.Embed(context.Background(), embReq)
if err != nil {
log.Fatal(err)
}

fmt.Printf("got %d embeddings", len(embs))
}
32 changes: 31 additions & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package embeddings

import "context"
import (
"context"
"encoding/base64"
"encoding/binary"
"fmt"
"math"
)

// Embedder fetches embeddings.
type Embedder[T any] interface {
Expand All @@ -21,3 +27,27 @@ func (e Embedding) ToFloat32() []float32 {
}
return floats
}

// Base64 is base64 encoded embedding string.
type Base64 string

// Decode decodes base64 encoded string into a slice of floats.
func (s Base64) Decode() ([]float64, error) {
decoded, err := base64.StdEncoding.DecodeString(string(s))
if err != nil {
return nil, err
}

if len(decoded)%8 != 0 {
return nil, fmt.Errorf("invalid base64 encoded string length")
}

floats := make([]float64, len(decoded)/8)

for i := 0; i < len(floats); i++ {
bits := binary.LittleEndian.Uint64(decoded[i*8 : (i+1)*8])
floats[i] = math.Float64frombits(bits)
}

return floats, nil
}
32 changes: 2 additions & 30 deletions openai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,15 @@ package openai
import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"

"github.com/milosgajdos/go-embeddings"
"github.com/milosgajdos/go-embeddings/request"
)

// EmbeddingString is base64 encoded embedding.
type EmbeddingString string

// Decode decodes base64 encoded string into a slice of floats.
func (s EmbeddingString) Decode() ([]float64, error) {
decoded, err := base64.StdEncoding.DecodeString(string(s))
if err != nil {
return nil, err
}

if len(decoded)%8 != 0 {
return nil, fmt.Errorf("invalid base64 encoded string length")
}

floats := make([]float64, len(decoded)/8)

for i := 0; i < len(floats); i++ {
bits := binary.LittleEndian.Uint64(decoded[i*8 : (i+1)*8])
floats[i] = math.Float64frombits(bits)
}

return floats, nil
}

// Usage tracks API token usage.
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
Expand Down Expand Up @@ -110,7 +82,7 @@ func toEmbeddingResp[T any](resp io.Reader) (*EmbeddingResponse, error) {
}

switch e := any(data).(type) {
case *EmbeddingResponseGen[EmbeddingString]:
case *EmbeddingResponseGen[embeddings.Base64]:
embData := make([]Data, 0, len(e.Data))
for _, d := range e.Data {
floats, err := d.Embedding.Decode()
Expand Down Expand Up @@ -181,7 +153,7 @@ func (c *Client) Embed(ctx context.Context, embReq *EmbeddingRequest) ([]*embedd

switch embReq.EncodingFormat {
case EncodingBase64:
embs, err = toEmbeddingResp[EmbeddingResponseGen[EmbeddingString]](resp.Body)
embs, err = toEmbeddingResp[EmbeddingResponseGen[embeddings.Base64]](resp.Body)
case EncodingFloat:
embs, err = toEmbeddingResp[EmbeddingResponseGen[[]float64]](resp.Body)
default:
Expand Down
85 changes: 85 additions & 0 deletions voyage/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package voyage

import (
"os"

"github.com/milosgajdos/go-embeddings"
"github.com/milosgajdos/go-embeddings/client"
)

const (
// BaseURL is VoyageAI HTTP API base URL.
BaseURL = "https://api.voyageai.com"
// EmbedAPIVersion is the latest stable embedding API version.
EmbedAPIVersion = "v1"
)

// Client is Voyage HTTP API client.
type Client struct {
opts Options
}

// Options are client options
type Options struct {
APIKey string
BaseURL string
Version string
HTTPClient *client.HTTP
}

// Option is functional graph option.
type Option func(*Options)

// NewClient creates a new HTTP API client and returns it.
// By default it reads the Voyage API key from VOYAGE_API_KEY
// env var and uses the default Go http.Client for making API requests.
// You can override the default options via the client methods.
func NewClient(opts ...Option) *Client {
options := Options{
APIKey: os.Getenv("VOYAGE_API_KEY"),
BaseURL: BaseURL,
Version: EmbedAPIVersion,
HTTPClient: client.NewHTTP(),
}

for _, apply := range opts {
apply(&options)
}

return &Client{
opts: options,
}
}

// NewEmbedder creates a client that implements embeddings.Embedder
func NewEmbedder(opts ...Option) embeddings.Embedder[*EmbeddingRequest] {
return NewClient(opts...)
}

// WithAPIKey sets the API key.
func WithAPIKey(apiKey string) Option {
return func(o *Options) {
o.APIKey = apiKey
}
}

// WithBaseURL sets the API base URL.
func WithBaseURL(baseURL string) Option {
return func(o *Options) {
o.BaseURL = baseURL
}
}

// WithVersion sets the API version.
func WithVersion(version string) Option {
return func(o *Options) {
o.Version = version
}
}

// WithHTTPClient sets the HTTP client.
func WithHTTPClient(httpClient *client.HTTP) Option {
return func(o *Options) {
o.HTTPClient = httpClient
}
}
52 changes: 52 additions & 0 deletions voyage/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package voyage

import (
"testing"

"github.com/milosgajdos/go-embeddings/client"
"github.com/stretchr/testify/assert"
)

const (
voyageAPIKey = "somekey"
)

func TestClient(t *testing.T) {
t.Setenv("VOYAGE_API_KEY", voyageAPIKey)

t.Run("API key", func(t *testing.T) {
c := NewClient()
assert.Equal(t, c.opts.APIKey, voyageAPIKey)

testVal := "foo"
c = NewClient(WithAPIKey(testVal))
assert.Equal(t, c.opts.APIKey, testVal)
})

t.Run("BaseURL", func(t *testing.T) {
c := NewClient()
assert.Equal(t, c.opts.BaseURL, BaseURL)

testVal := "http://foo"
c = NewClient(WithBaseURL(testVal))
assert.Equal(t, c.opts.BaseURL, testVal)
})

t.Run("version", func(t *testing.T) {
c := NewClient()
assert.Equal(t, c.opts.Version, EmbedAPIVersion)

testVal := "v3"
c = NewClient(WithVersion(testVal))
assert.Equal(t, c.opts.Version, testVal)
})

t.Run("http client", func(t *testing.T) {
c := NewClient()
assert.NotNil(t, c.opts.HTTPClient)

testVal := client.NewHTTP()
c = NewClient(WithHTTPClient(testVal))
assert.NotNil(t, c.opts.HTTPClient)
})
}
Loading

0 comments on commit dfa2b14

Please sign in to comment.