Skip to content

Commit

Permalink
Add openai embeddings (#1)
Browse files Browse the repository at this point in the history
Add OpenAI API client

Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos authored Nov 29, 2023
1 parent a0156bd commit 8448691
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 0 deletions.
44 changes: 44 additions & 0 deletions cmd/openai/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/openai"
)

var (
input string
model string
encoding string
)

func init() {
flag.StringVar(&input, "input", "", "input data")
flag.StringVar(&model, "model", string(openai.TextAdaV2), "model name")
flag.StringVar(&encoding, "encoding", string(openai.EncodingFloat), "encoding format")
}

func main() {
flag.Parse()

c, err := openai.NewClient()
if err != nil {
log.Fatal(err)
}

embReq := &openai.EmbeddingRequest{
Input: input,
Model: openai.Model(model),
EncodingFormat: openai.EncodingFormat(encoding),
}

embs, err := c.Embeddings(context.Background(), embReq)
if err != nil {
log.Fatal(err)
}

fmt.Printf("got %d embeddings", len(embs))
}
138 changes: 138 additions & 0 deletions openai/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package openai

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
)

const (
// BaseURL is OpenAI HTTP API base URL.
BaseURL = "https://api.openai.com/v1"
// Org header
OrgHeader = "OpenAI-Organization"
)

// Client is OpenAI HTTP API client.
type Client struct {
apiKey string
baseURL string
orgID string
hc *http.Client
}

// NewClient creates a new HTTP client and returns it.
// It reads the OpenAI API key from OPENAI_API_KEY env var
// and uses the default Go http.Client.
// You can override the default options by using the
// client methods.
func NewClient() (*Client, error) {
return &Client{
apiKey: os.Getenv("OPENAI_API_KEY"),
baseURL: BaseURL,
orgID: "",
hc: &http.Client{},
}, nil
}

// WithAPIKey sets the API key.
func (c *Client) WithAPIKey(apiKey string) *Client {
c.apiKey = apiKey
return c
}

// WithBaseURL sets the API base URL.
func (c *Client) WithBaseURL(baseURL string) *Client {
c.baseURL = baseURL
return c
}

// WithOrgID sets the organization ID.
func (c *Client) WithOrgID(orgID string) *Client {
c.orgID = orgID
return c
}

// WithHTTPClient sets the HTTP client.
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
c.hc = httpClient
return c
}

// ReqOption is http requestion functional option.
type ReqOption func(*http.Request)

// WithSetHeader sets the header key to value val.
func WithSetHeader(key, val string) ReqOption {
return func(req *http.Request) {
if req.Header == nil {
req.Header = make(http.Header)
}
req.Header.Set(key, val)
}
}

// WithAddHeader adds the val to key header.
func WithAddHeader(key, val string) ReqOption {
return func(req *http.Request) {
if req.Header == nil {
req.Header = make(http.Header)
}
req.Header.Add(key, val)
}
}

func (c *Client) newRequest(ctx context.Context, method, uri string, body io.Reader, opts ...ReqOption) (*http.Request, error) {
if ctx == nil {
ctx = context.Background()
}
if body == nil {
body = &bytes.Reader{}
}

req, err := http.NewRequestWithContext(ctx, method, uri, body)
if err != nil {
return nil, err
}

for _, setOption := range opts {
setOption(req)
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
if len(c.orgID) != 0 {
req.Header.Set("OpenAI-Organization", c.orgID)
}

req.Header.Set("Accept", "application/json; charset=utf-8")
if body != nil {
// if no content-type is specified we default to json
if ct := req.Header.Get("Content-Type"); len(ct) == 0 {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}
}

return req, nil
}

func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
resp, err := c.hc.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
return resp, nil
}
defer resp.Body.Close()

var apiErr APIError
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
return nil, err
}

return nil, apiErr
}
146 changes: 146 additions & 0 deletions openai/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package openai

import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"
)

// Usage tracks API token usage.
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
}

// Embedding is openai API vector embedding.
type Embedding struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}

// EmbeddingString is base64 encoded embedding.
type EmbeddingString string

func (s EmbeddingString) Decode() ([]float64, error) {
decoded, err := base64.StdEncoding.DecodeString(string(s))
if err != nil {
return nil, err
}

if len(decoded)%8 != 0 {
return nil, fmt.Errorf("invalid base64 encoded string length")
}

floats := make([]float64, len(decoded)/8)

for i := 0; i < len(floats); i++ {
bits := binary.LittleEndian.Uint64(decoded[i*8 : (i+1)*8])
floats[i] = math.Float64frombits(bits)
}

return floats, nil
}

// EmbeddingRequest is serialized and sent to the API server.
type EmbeddingRequest struct {
Input any `json:"input"`
Model Model `json:"model"`
User string `json:"user"`
EncodingFormat EncodingFormat `json:"encoding_format,omitempty"`
}

// Data is used for deserializing response data.
type Data[T any] struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding T `json:"embedding"`
}

// EmbeddingResponse is the API response from a Create embeddings request.
type EmbeddingResponse[T any] struct {
Object string `json:"object"`
Data []Data[T] `json:"data"`
Model Model `json:"model"`
Usage Usage `json:"usage"`
}

func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
data := new(T)
if err := json.NewDecoder(resp).Decode(data); err != nil {
return nil, err
}

switch e := any(data).(type) {
case *EmbeddingResponse[EmbeddingString]:
embs := make([]*Embedding, 0, len(e.Data))
for _, d := range e.Data {
floats, err := d.Embedding.Decode()
if err != nil {
return nil, err
}
emb := &Embedding{
Object: d.Object,
Index: d.Index,
Embedding: floats,
}
embs = append(embs, emb)
}
return embs, nil
case *EmbeddingResponse[[]float64]:
embs := make([]*Embedding, 0, len(e.Data))
for _, d := range e.Data {
emb := &Embedding{
Object: d.Object,
Index: d.Index,
Embedding: d.Embedding,
}
embs = append(embs, emb)
}
return embs, nil
}

return nil, ErrInValidData
}

// Embeddings returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
u, err := url.Parse(c.baseURL + "/embeddings")
if err != nil {
return nil, err
}

var body = &bytes.Buffer{}
enc := json.NewEncoder(body)
enc.SetEscapeHTML(false)
if err := enc.Encode(embReq); err != nil {
return nil, err
}

req, err := c.newRequest(ctx, http.MethodPost, u.String(), body)
if err != nil {
return nil, err
}

resp, err := c.doRequest(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

switch embReq.EncodingFormat {
case EncodingBase64:
return ToEmbeddings[EmbeddingResponse[EmbeddingString]](resp.Body)
case EncodingFloat:
return ToEmbeddings[EmbeddingResponse[[]float64]](resp.Body)
}

return nil, fmt.Errorf("unknown encoding: %v", embReq.EncodingFormat)
}
27 changes: 27 additions & 0 deletions openai/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package openai

import (
"encoding/json"
"errors"
)

var (
ErrInValidData = errors.New("invalid data")
)

type APIError struct {
Err struct {
Message string `json:"message"`
Type string `json:"type"`
Param *string `json:"param,omitempty"`
Code any `json:"code,omitempty"`
} `json:"error"`
}

func (e APIError) Error() string {
b, err := json.Marshal(e)
if err != nil {
return "unknown error"
}
return string(b)
}
16 changes: 16 additions & 0 deletions openai/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package openai

// Model is embedding model.
type Model string

const (
TextAdaV2 Model = "text-embedding-ada-002"
)

// EncodingFormat for embedding API requests.
type EncodingFormat string

const (
EncodingFloat EncodingFormat = "float"
EncodingBase64 EncodingFormat = "base64"
)

0 comments on commit 8448691

Please sign in to comment.