Skip to content

Commit

Permalink
update: make embedding.Base64.Decode return Embedding (#26)
Browse files Browse the repository at this point in the history
Also, update linter and Go runtime versions

Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos authored May 8, 2024
1 parent 3c3905b commit b1ee5f2
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 14 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ jobs:
matrix:
os: [ ubuntu-latest ]
go:
- '1.20'
- '1.21'
- '1.22'

steps:

Expand Down Expand Up @@ -68,8 +68,8 @@ jobs:
matrix:
os: [ ubuntu-latest ]
go:
- '1.20'
- '1.21'
- '1.22'

steps:

Expand All @@ -84,4 +84,4 @@ jobs:
- name: Run linter
uses: golangci/golangci-lint-action@v3
with:
version: v1.55
version: v1.58
6 changes: 3 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ linters:
- gofmt
- revive
- ineffassign
- vet
- govet
- unused
- misspell
- bodyclose
Expand All @@ -22,7 +22,7 @@ linters:
- unconvert
- whitespace

run:
issues:
deadline: 2m
skip-dirs:
exlude-dirs:
- vendor
6 changes: 4 additions & 2 deletions embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (e Embedding) ToFloat32() []float32 {
type Base64 string

// Decode decodes base64 encoded string into a slice of floats.
func (s Base64) Decode() ([]float64, error) {
func (s Base64) Decode() (*Embedding, error) {
decoded, err := base64.StdEncoding.DecodeString(string(s))
if err != nil {
return nil, err
Expand All @@ -49,5 +49,7 @@ func (s Base64) Decode() ([]float64, error) {
floats[i] = math.Float64frombits(bits)
}

return floats, nil
return &Embedding{
Vector: floats,
}, nil
}
4 changes: 2 additions & 2 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func TestBase64Decode(t *testing.T) {
if tc.wantErr {
t.Fatal("expected error")
}
if !reflect.DeepEqual(got, tc.exp) {
t.Fatalf("expected: %v, got: %v", tc.exp, got)
if !reflect.DeepEqual(got.Vector, tc.exp) {
t.Fatalf("expected: %v, got: %v", tc.exp, got.Vector)
}
})
}
Expand Down
4 changes: 2 additions & 2 deletions openai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ func toEmbeddingResp[T any](resp io.Reader) (*EmbeddingResponse, error) {
case *EmbeddingResponseGen[embeddings.Base64]:
embData := make([]Data, 0, len(e.Data))
for _, d := range e.Data {
floats, err := d.Embedding.Decode()
emb, err := d.Embedding.Decode()
if err != nil {
return nil, err
}
embData = append(embData, Data{
Object: d.Object,
Index: d.Index,
Embedding: floats,
Embedding: emb.Vector,
})
}
return &EmbeddingResponse{
Expand Down
4 changes: 2 additions & 2 deletions voyage/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ func toEmbeddingResp[T any](resp io.Reader) (*EmbeddingResponse, error) {
case *EmbeddingResponseGen[embeddings.Base64]:
embData := make([]Data, 0, len(e.Data))
for _, d := range e.Data {
floats, err := d.Embedding.Decode()
emb, err := d.Embedding.Decode()
if err != nil {
return nil, err
}
embData = append(embData, Data{
Object: d.Object,
Index: d.Index,
Embedding: floats,
Embedding: emb.Vector,
})
}
return &EmbeddingResponse{
Expand Down

0 comments on commit b1ee5f2

Please sign in to comment.