From 72ce2ed38fb94619f1d22c453ef635b7ed2bcb1c Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Tue, 9 Apr 2024 14:33:44 +0100 Subject: [PATCH] feat: add ollama embeddings API Signed-off-by: Milos Gajdos --- README.md | 14 +++++++-- cmd/ollama/main.go | 42 +++++++++++++++++++++++++++ ollama/client.go | 61 +++++++++++++++++++++++++++++++++++++++ ollama/client_test.go | 28 ++++++++++++++++++ ollama/embedding.go | 67 +++++++++++++++++++++++++++++++++++++++++++ ollama/error.go | 17 +++++++++++ openai/client.go | 2 +- 7 files changed, 227 insertions(+), 4 deletions(-) create mode 100644 cmd/ollama/main.go create mode 100644 ollama/client.go create mode 100644 ollama/client_test.go create mode 100644 ollama/embedding.go create mode 100644 ollama/error.go diff --git a/README.md b/README.md index e70594a..ffcad48 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,10 @@ This project provides an implementation of API clients for fetching embeddings f Currently supported APIs: * [x] [OpenAI](https://platform.openai.com/docs/api-reference/embeddings) -* [x] [Cohere AI](https://docs.cohere.com/reference/embed) -* [x] [Google Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings) +* [x] [Cohere](https://docs.cohere.com/reference/embed) +* [x] [Google Vertex](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings) * [x] [VoyageAI](https://docs.voyageai.com/reference/embeddings-api) +* [x] [Ollama](https://ollama.com/) You can find sample programs that demonstrate how to use the client packages to fetch the embeddings in `cmd` directory of this project. @@ -19,7 +20,10 @@ It's essentially a Go rewrite of character and recursive character text splitter ## Environment variables -Each client package lets you initialize a default API client for a specific embeddings provider by reading the API keys from environment variables. 
+> [!NOTE] +> Each client package lets you initialize a default API client for a specific embeddings provider by reading the API keys from environment variables. + +Here's a list of the env vars for each supported client: ### OpenAI @@ -36,6 +40,10 @@ Each client package lets you initialize a default API client for a specific embe * `GOOGLE_PROJECT_ID`: Google Project ID * `VOYAGE_API_KEY`: VoyageAI API key +### Voyage + +* `VOYAGE_API_KEY`: Voyage AI API key + ## nix The project provides a simple `nix` flake tha leverages [gomod2nix](https://github.com/nix-community/gomod2nix) for consistent Go environments and builds. diff --git a/cmd/ollama/main.go b/cmd/ollama/main.go new file mode 100644 index 0000000..b60c445 --- /dev/null +++ b/cmd/ollama/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + + "github.com/milosgajdos/go-embeddings/ollama" +) + +var ( + prompt string + model string +) + +func init() { + flag.StringVar(&prompt, "prompt", "what is life", "input prompt") + flag.StringVar(&model, "model", "", "model name") +} + +func main() { + flag.Parse() + + if model == "" { + log.Fatal("missing ollama model") + } + + c := ollama.NewClient() + + embReq := &ollama.EmbeddingRequest{ + Prompt: prompt, + Model: model, + } + + embs, err := c.Embed(context.Background(), embReq) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("got %d embeddings", len(embs)) +} diff --git a/ollama/client.go b/ollama/client.go new file mode 100644 index 0000000..ebb6487 --- /dev/null +++ b/ollama/client.go @@ -0,0 +1,61 @@ +package ollama + +import ( + "github.com/milosgajdos/go-embeddings" + "github.com/milosgajdos/go-embeddings/client" +) + +const ( + // BaseURL is Ollama HTTP API embeddings base URL. + BaseURL = "http://localhost:11434/api" +) + +// Client is an Ollama HTTP API client. +type Client struct { + opts Options +} + +type Options struct { + BaseURL string + HTTPClient *client.HTTP +} + +// Option is functional graph option. 
+type Option func(*Options) + +// NewClient creates a new Ollama HTTP API client and returns it. +// You can override the default options via the client methods. +func NewClient(opts ...Option) *Client { + options := Options{ + BaseURL: BaseURL, + HTTPClient: client.NewHTTP(), + } + + for _, apply := range opts { + apply(&options) + } + + return &Client{ + opts: options, + } +} + +// NewEmbedder creates a client that implements embeddings.Embedder. +func NewEmbedder(opts ...Option) embeddings.Embedder[*EmbeddingRequest] { + return NewClient(opts...) +} + +// WithBaseURL sets the API base URL. +func WithBaseURL(baseURL string) Option { + return func(o *Options) { + o.BaseURL = baseURL + } +} + +// WithHTTPClient sets the HTTP client +// used for making API requests. +func WithHTTPClient(httpClient *client.HTTP) Option { + return func(o *Options) { + o.HTTPClient = httpClient + } +} diff --git a/ollama/client_test.go b/ollama/client_test.go new file mode 100644 index 0000000..0f57544 --- /dev/null +++ b/ollama/client_test.go @@ -0,0 +1,28 @@ +package ollama + +import ( + "testing" + + "github.com/milosgajdos/go-embeddings/client" + "github.com/stretchr/testify/assert" +) + +func TestClient(t *testing.T) { + t.Run("BaseURL", func(t *testing.T) { + c := NewClient() + assert.Equal(t, c.opts.BaseURL, BaseURL) + + testVal := "http://foo" + c = NewClient(WithBaseURL(testVal)) + assert.Equal(t, c.opts.BaseURL, testVal) + }) + + t.Run("http client", func(t *testing.T) { + c := NewClient() + assert.NotNil(t, c.opts.HTTPClient) + + testVal := client.NewHTTP() + c = NewClient(WithHTTPClient(testVal)) + assert.NotNil(t, c.opts.HTTPClient) + }) +} diff --git a/ollama/embedding.go b/ollama/embedding.go new file mode 100644 index 0000000..4a076df --- /dev/null +++ b/ollama/embedding.go @@ -0,0 +1,67 @@ +package ollama + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/url" + + "github.com/milosgajdos/go-embeddings" + 
"github.com/milosgajdos/go-embeddings/request" +) + +// EmbeddingRequest is serialized and sent to the API server. +type EmbeddingRequest struct { + Prompt any `json:"prompt"` + Model string `json:"model"` +} + +// EmbedddingResponse is the response received from the API. +type EmbedddingResponse struct { + Embedding []float64 `json:"embedding"` +} + +// ToEmbeddings converts the API response +// into a slice of embeddings and returns it. +func (e *EmbedddingResponse) ToEmbeddings() ([]*embeddings.Embedding, error) { + floats := make([]float64, len(e.Embedding)) + copy(floats, e.Embedding) + return []*embeddings.Embedding{ + {Vector: floats}, + }, nil +} + +// Embed returns the embedding for the prompt in EmbeddingRequest. +func (c *Client) Embed(ctx context.Context, embReq *EmbeddingRequest) ([]*embeddings.Embedding, error) { + u, err := url.Parse(c.opts.BaseURL + "/embeddings") + if err != nil { + return nil, err + } + + var body = &bytes.Buffer{} + enc := json.NewEncoder(body) + enc.SetEscapeHTML(false) + if err := enc.Encode(embReq); err != nil { + return nil, err + } + + options := []request.Option{} + req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...) + if err != nil { + return nil, err + } + + resp, err := request.Do[APIError](c.opts.HTTPClient, req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + e := new(EmbedddingResponse) + if err := json.NewDecoder(resp.Body).Decode(e); err != nil { + return nil, err + } + + return e.ToEmbeddings() +} diff --git a/ollama/error.go b/ollama/error.go new file mode 100644 index 0000000..7962379 --- /dev/null +++ b/ollama/error.go @@ -0,0 +1,17 @@ +package ollama + +import "encoding/json" + +// APIError is an Ollama API error. +type APIError struct { + ErrorMessage string `json:"error"` +} + +// Error implements the error interface. 
+func (e APIError) Error() string { + b, err := json.Marshal(e) + if err != nil { + return "unknown error" + } + return string(b) +} diff --git a/openai/client.go b/openai/client.go index 6210c3d..7c9239f 100644 --- a/openai/client.go +++ b/openai/client.go @@ -32,7 +32,7 @@ type Options struct { // Option is functional graph option. type Option func(*Options) -// NewClient creates a new HTTP API client and returns it. +// NewClient creates a new OpenAI HTTP API client and returns it. // By default it reads the OpenAI API key from OPENAI_API_KEY // env var and uses the default Go http.Client for making API requests. // You can override the default options via the client methods.