feat: add ollama embeddings API #23

Merged: 1 commit, Apr 9, 2024
14 changes: 11 additions & 3 deletions README.md
@@ -8,9 +8,10 @@ This project provides an implementation of API clients for fetching embeddings f

Currently supported APIs:
* [x] [OpenAI](https://platform.openai.com/docs/api-reference/embeddings)
* [x] [Cohere AI](https://docs.cohere.com/reference/embed)
* [x] [Google Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings)
* [x] [Cohere](https://docs.cohere.com/reference/embed)
* [x] [Google Vertex](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings)
* [x] [VoyageAI](https://docs.voyageai.com/reference/embeddings-api)
* [x] [Ollama](https://ollama.com/)

You can find sample programs that demonstrate how to use the client packages to fetch the embeddings in `cmd` directory of this project.

@@ -19,7 +20,10 @@ It's essentially a Go rewrite of character and recursive character text splitter

## Environment variables

Each client package lets you initialize a default API client for a specific embeddings provider by reading the API keys from environment variables.
> [!NOTE]
> Each client package lets you initialize a default API client for a specific embeddings provider by reading the API keys from environment variables.

Here's a list of the env vars for each supported client:

### OpenAI

@@ -36,6 +40,10 @@ Each client package lets you initialize a default API client for a specific embe
* `GOOGLE_PROJECT_ID`: Google Project ID
* `VOYAGE_API_KEY`: VoyageAI API key

### Voyage

* `VOYAGE_API_KEY`: Voyage AI API key

## nix

The project provides a simple `nix` flake that leverages [gomod2nix](https://github.com/nix-community/gomod2nix) for consistent Go environments and builds.
42 changes: 42 additions & 0 deletions cmd/ollama/main.go
@@ -0,0 +1,42 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/ollama"
)

var (
prompt string
model string
)

func init() {
flag.StringVar(&prompt, "prompt", "what is life", "input prompt")
flag.StringVar(&model, "model", "", "model name")
}

func main() {
flag.Parse()

if model == "" {
log.Fatal("missing ollama model")
}

c := ollama.NewClient()

embReq := &ollama.EmbeddingRequest{
Prompt: prompt,
Model: model,
}

embs, err := c.Embed(context.Background(), embReq)
if err != nil {
log.Fatal(err)
}

fmt.Printf("got %d embeddings\n", len(embs))
}
61 changes: 61 additions & 0 deletions ollama/client.go
@@ -0,0 +1,61 @@
package ollama

import (
"github.com/milosgajdos/go-embeddings"
"github.com/milosgajdos/go-embeddings/client"
)

const (
// BaseURL is Ollama HTTP API embeddings base URL.
BaseURL = "http://localhost:11434/api"
)

// Client is an Ollama HTTP API client.
type Client struct {
opts Options
}

// Options are the Ollama client options.
type Options struct {
BaseURL string
HTTPClient *client.HTTP
}

// Option is a functional option for the client.
type Option func(*Options)

// NewClient creates a new Ollama HTTP API client and returns it.
// You can override the default options via the client methods.
func NewClient(opts ...Option) *Client {
options := Options{
BaseURL: BaseURL,
HTTPClient: client.NewHTTP(),
}

for _, apply := range opts {
apply(&options)
}

return &Client{
opts: options,
}
}

// NewEmbedder creates a client that implements embeddings.Embedder
func NewEmbedder(opts ...Option) embeddings.Embedder[*EmbeddingRequest] {
return NewClient(opts...)
}

// WithBaseURL sets the API base URL.
func WithBaseURL(baseURL string) Option {
return func(o *Options) {
o.BaseURL = baseURL
}
}

// WithHTTPClient sets the HTTP client.
func WithHTTPClient(httpClient *client.HTTP) Option {
return func(o *Options) {
o.HTTPClient = httpClient
}
}
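The functional options pattern used by `NewClient` can be shown in isolation. The following is only a minimal sketch: `newOptions` and `defaultBaseURL` are hypothetical names introduced for illustration, and the HTTP client field is left out so the example has no dependencies outside the standard library.

```go
package main

import "fmt"

// defaultBaseURL mirrors the BaseURL constant from the diff.
// The name itself is an assumption made for this sketch.
const defaultBaseURL = "http://localhost:11434/api"

// Options holds the client settings. The real struct also carries
// an HTTP client, which is omitted here.
type Options struct {
	BaseURL string
}

// Option mutates Options in place.
type Option func(*Options)

// WithBaseURL overrides the default base URL.
func WithBaseURL(baseURL string) Option {
	return func(o *Options) {
		o.BaseURL = baseURL
	}
}

// newOptions is a hypothetical helper: it starts from the defaults
// and then applies each option in the order it was passed.
func newOptions(opts ...Option) Options {
	options := Options{BaseURL: defaultBaseURL}
	for _, apply := range opts {
		apply(&options)
	}
	return options
}

func main() {
	fmt.Println(newOptions().BaseURL)
	fmt.Println(newOptions(WithBaseURL("http://ollama.internal:11434/api")).BaseURL)
}
```

Because options are applied in order, a later option overrides an earlier one that sets the same field.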
28 changes: 28 additions & 0 deletions ollama/client_test.go
@@ -0,0 +1,28 @@
package ollama

import (
"testing"

"github.com/milosgajdos/go-embeddings/client"
"github.com/stretchr/testify/assert"
)

func TestClient(t *testing.T) {
t.Run("BaseURL", func(t *testing.T) {
c := NewClient()
assert.Equal(t, BaseURL, c.opts.BaseURL)

testVal := "http://foo"
c = NewClient(WithBaseURL(testVal))
assert.Equal(t, testVal, c.opts.BaseURL)
})

t.Run("http client", func(t *testing.T) {
c := NewClient()
assert.NotNil(t, c.opts.HTTPClient)

testVal := client.NewHTTP()
c = NewClient(WithHTTPClient(testVal))
// Check that the option installed this exact client, not just any non-nil one.
assert.Same(t, testVal, c.opts.HTTPClient)
})
}
67 changes: 67 additions & 0 deletions ollama/embedding.go
@@ -0,0 +1,67 @@
package ollama

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/url"

"github.com/milosgajdos/go-embeddings"
"github.com/milosgajdos/go-embeddings/request"
)

// EmbeddingRequest is serialized and sent to the API server.
type EmbeddingRequest struct {
Prompt any `json:"prompt"`
Model string `json:"model"`
}

// EmbedddingResponse is the response received from the API.
type EmbedddingResponse struct {
Embedding []float64 `json:"embedding"`
}

// ToEmbeddings converts the API response
// into a slice of embeddings and returns it.
func (e *EmbedddingResponse) ToEmbeddings() ([]*embeddings.Embedding, error) {
floats := make([]float64, len(e.Embedding))
copy(floats, e.Embedding)
return []*embeddings.Embedding{
{Vector: floats},
}, nil
}

// Embed returns the embedding of the prompt in EmbeddingRequest,
// wrapped in a single-element slice.
func (c *Client) Embed(ctx context.Context, embReq *EmbeddingRequest) ([]*embeddings.Embedding, error) {
u, err := url.Parse(c.opts.BaseURL + "/embeddings")
if err != nil {
return nil, err
}

body := &bytes.Buffer{}
enc := json.NewEncoder(body)
enc.SetEscapeHTML(false)
if err := enc.Encode(embReq); err != nil {
return nil, err
}

options := []request.Option{}
req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...)
if err != nil {
return nil, err
}

resp, err := request.Do[APIError](c.opts.HTTPClient, req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

e := new(EmbedddingResponse)
if err := json.NewDecoder(resp.Body).Decode(e); err != nil {
return nil, err
}

return e.ToEmbeddings()
}
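The request and response shapes above can be exercised without a running Ollama instance. The sketch below is an illustration only: it uses the standard library instead of the project's `client` and `request` packages, and `newStubServer`, `embed`, the lowercase types, and the `stub-model` name are all hypothetical. The stub server returns a fixed three-element vector rather than a real embedding.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// embeddingRequest mirrors the JSON fields of EmbeddingRequest in the diff.
type embeddingRequest struct {
	Prompt string `json:"prompt"`
	Model  string `json:"model"`
}

// embeddingResponse mirrors the JSON fields of the response type in the diff.
type embeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

// newStubServer is a hypothetical stand-in for an Ollama server.
// It answers POSTs to /api/embeddings with a fixed vector and
// rejects requests that do not name a model.
func newStubServer() *httptest.Server {
	mux := http.NewServeMux()
	mux.HandleFunc("/api/embeddings", func(w http.ResponseWriter, r *http.Request) {
		var req embeddingRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Model == "" {
			w.WriteHeader(http.StatusBadRequest)
			json.NewEncoder(w).Encode(map[string]string{"error": "invalid request"})
			return
		}
		json.NewEncoder(w).Encode(embeddingResponse{Embedding: []float64{0.1, 0.2, 0.3}})
	})
	return httptest.NewServer(mux)
}

// embed posts the request to baseURL+"/embeddings" and decodes the vector.
func embed(baseURL string, req embeddingRequest) ([]float64, error) {
	body := &bytes.Buffer{}
	if err := json.NewEncoder(body).Encode(req); err != nil {
		return nil, err
	}

	resp, err := http.Post(baseURL+"/embeddings", "application/json", body)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status: %s", resp.Status)
	}

	var out embeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}
	return out.Embedding, nil
}

func main() {
	srv := newStubServer()
	defer srv.Close()

	vec, err := embed(srv.URL+"/api", embeddingRequest{Prompt: "what is life", Model: "stub-model"})
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("got a vector with %d dimensions\n", len(vec))
}
```

Pointing the real client at such a stub should be possible through `WithBaseURL`, which is what makes the base URL option useful in tests.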
17 changes: 17 additions & 0 deletions ollama/error.go
@@ -0,0 +1,17 @@
package ollama

import "encoding/json"

// APIError is Ollama API error.
type APIError struct {
ErrorMessage string `json:"error"`
}

// Error implements the error interface.
func (e APIError) Error() string {
b, err := json.Marshal(e)
if err != nil {
return "unknown error"
}
return string(b)
}
2 changes: 1 addition & 1 deletion openai/client.go
@@ -32,7 +32,7 @@ type Options struct {
// Option is functional graph option.
type Option func(*Options)

// NewClient creates a new HTTP API client and returns it.
// NewClient creates a new OpenAI HTTP API client and returns it.
// By default it reads the OpenAI API key from OPENAI_API_KEY
// env var and uses the default Go http.Client for making API requests.
// You can override the default options via the client methods.