Skip to content

Commit

Permalink
Add hugging face API but it's a total mess
Browse files Browse the repository at this point in the history
Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos committed Nov 30, 2023
1 parent 23f68bc commit 2c5465e
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 0 deletions.
49 changes: 49 additions & 0 deletions cmd/huggingface/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/cohere"
"github.com/milosgajdos/go-embeddings/huggingface"
)

var (
input string
model string
wait bool
)

func init() {
flag.StringVar(&input, "input", "what is life", "input data")
flag.StringVar(&model, "model", string(cohere.EnglishV3), "model name")
flag.BoolVar(&wait, "wait", false, "wait for model to start")
}

func main() {
flag.Parse()

c := huggingface.NewClient().
WithModel(model)

embReq := &huggingface.EmbeddingRequest{
Inputs: []string{input},
Options: huggingface.Options{
WaitForModel: &wait,
},
}

embResp, err := c.Embeddings(context.Background(), embReq)
if err != nil {
log.Fatal(err)
}

embs, err := huggingface.ToEmbeddings(embResp)
if err != nil {
log.Fatal(err)
}

fmt.Printf("got %d embeddings", len(embs))
}
74 changes: 74 additions & 0 deletions huggingface/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package huggingface

import (
"encoding/json"
"net/http"
"os"
)

const (
// BaseURL is Cohere HTTP API base URL.
BaseURL = "https://api-inference.huggingface.co/models"
)

// Client is Cohere HTTP API client.
type Client struct {
apiKey string
baseURL string
model string
hc *http.Client
}

// NewClient creates a new HTTP API client and returns it.
// By default it reads the Cohere API key from HUGGINGFACE_API_KEY
// env var and uses the default Go http.Client for making API requests.
// You can override the default options via the client methods.
func NewClient() *Client {
return &Client{
apiKey: os.Getenv("HUGGINGFACE_API_KEY"),
baseURL: BaseURL,
hc: &http.Client{},
}
}

// WithAPIKey sets the API key.
func (c *Client) WithAPIKey(apiKey string) *Client {
c.apiKey = apiKey
return c
}

// WithBaseURL sets the API base URL.
func (c *Client) WithBaseURL(baseURL string) *Client {
c.baseURL = baseURL
return c
}

// WithModel sets the model name
func (c *Client) WithModel(model string) *Client {
c.model = model
return c
}

// WithHTTPClient sets the HTTP client.
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
c.hc = httpClient
return c
}

func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
resp, err := c.hc.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
return resp, nil
}
defer resp.Body.Close()

var apiErr APIError
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
return nil, err
}

return nil, apiErr
}
52 changes: 52 additions & 0 deletions huggingface/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package huggingface

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

const (
huggingFaceKey = "somekey"
)

func TestClient(t *testing.T) {
t.Setenv("HUGGINGFACE_API_KEY", huggingFaceKey)

t.Run("API key", func(t *testing.T) {
c := NewClient()
assert.Equal(t, c.apiKey, huggingFaceKey)

testVal := "foo"
c.WithAPIKey(testVal)
assert.Equal(t, c.apiKey, testVal)
})

t.Run("BaseURL", func(t *testing.T) {
c := NewClient()
assert.Equal(t, c.baseURL, BaseURL)

testVal := "http://foo"
c.WithBaseURL(testVal)
assert.Equal(t, c.baseURL, testVal)
})

t.Run("Model", func(t *testing.T) {
c := NewClient()
assert.Equal(t, c.model, "")

testVal := "foo/bar"
c.WithModel(testVal)
assert.Equal(t, c.model, testVal)
})

t.Run("http client", func(t *testing.T) {
c := NewClient()
assert.NotNil(t, c.hc)

testVal := &http.Client{}
c.WithHTTPClient(testVal)
assert.NotNil(t, c.hc)
})
}
81 changes: 81 additions & 0 deletions huggingface/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package huggingface

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/url"

"github.com/milosgajdos/go-embeddings"
"github.com/milosgajdos/go-embeddings/request"
)

// EmbeddingRequest sent to API endpoint.
type EmbeddingRequest struct {
Inputs []string `json:"inputs"`
Options Options `json:"options,omitempty"`
}

// Options
type Options struct {
WaitForModel *bool `json:"wait_for_model,omitempty"`
}

// EmbedddingResponse is returned by API.
// TODO: hugging face APIs are a mess
type EmbedddingResponse [][][][]float64

// ToEmbeddings converts the raw API response,
// parses it into a slice of embeddings and returns it.
func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) {
emb := *e
embs := make([]*embeddings.Embedding, 0, len(emb))
//for i := range emb {
// vals := emb[i]
// floats := make([]float64, len(vals))
// copy(floats, vals)
// emb := &embeddings.Embedding{
// Vector: floats,
// }
// embs = append(embs, emb)
//}
return embs, nil
}

// Embeddings returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbedddingResponse, error) {
u, err := url.Parse(c.baseURL + "/" + c.model)
if err != nil {
return nil, err
}

var body = &bytes.Buffer{}
enc := json.NewEncoder(body)
enc.SetEscapeHTML(false)
if err := enc.Encode(embReq); err != nil {
return nil, err
}

options := []request.Option{
request.WithBearer(c.apiKey),
}

req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...)
if err != nil {
return nil, err
}

resp, err := c.doRequest(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

e := new(EmbedddingResponse)
if err := json.NewDecoder(resp.Body).Decode(e); err != nil {
return nil, err
}

return e, nil
}
17 changes: 17 additions & 0 deletions huggingface/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package huggingface

import "encoding/json"

// APIError is error returned by API
type APIError struct {
Message string `json:"error"`
}

// Error implements errors interface.
func (e APIError) Error() string {
b, err := json.Marshal(e)
if err != nil {
return "unknown error"
}
return string(b)
}

0 comments on commit 2c5465e

Please sign in to comment.