Skip to content

Commit

Permalink
refactor: ToEmbeddings is now a method on API response models (#12)
Browse files Browse the repository at this point in the history
Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos authored Dec 20, 2023
1 parent 7632334 commit 3f8a97b
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 45 deletions.
2 changes: 1 addition & 1 deletion cmd/cohere/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func main() {
log.Fatal(err)
}

embs, err := cohere.ToEmbeddings(embResp)
embs, err := embResp.ToEmbeddings()
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func main() {
log.Fatal(err)
}

embs, err := openai.ToEmbeddings(embResp)
embs, err := embResp.ToEmbeddings()
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/vertexai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func main() {
log.Fatal(err)
}

embs, err := vertexai.ToEmbeddings(embResp)
embs, err := embResp.ToEmbeddings()
if err != nil {
log.Fatal(err)
}
Expand Down
26 changes: 13 additions & 13 deletions cohere/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,9 @@ type EmbedddingResponse struct {
Meta *Meta `json:"meta,omitempty"`
}

// Meta stores API response metadata.
type Meta struct {
APIVersion *APIVersion `json:"api_version,omitempty"`
}

// APIVersion stores metadata API version.
type APIVersion struct {
Version string `json:"version"`
}

// ToEmbeddings converts the raw API response,
// parses it into a slice of embeddings and returns it.
func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) {
// ToEmbeddings converts the API response,
// into a slice of embeddings and returns it.
func (e *EmbedddingResponse) ToEmbeddings() ([]*embeddings.Embedding, error) {
embs := make([]*embeddings.Embedding, 0, len(e.Embeddings))
for _, e := range e.Embeddings {
floats := make([]float64, len(e))
Expand All @@ -50,6 +40,16 @@ func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) {
return embs, nil
}

// Meta stores API response metadata.
type Meta struct {
APIVersion *APIVersion `json:"api_version,omitempty"`
}

// APIVersion stores metadata API version.
type APIVersion struct {
Version string `json:"version"`
}

// Embeddings returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbedddingResponse, error) {
u, err := url.Parse(c.baseURL + "/" + c.version + "/embed")
Expand Down
30 changes: 15 additions & 15 deletions openai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ type EmbeddingResponse struct {
Usage Usage `json:"usage"`
}

// ToEmbeddings converts the API response,
// into a slice of embeddings and returns it.
func (e *EmbeddingResponse) ToEmbeddings() ([]*embeddings.Embedding, error) {
embs := make([]*embeddings.Embedding, 0, len(e.Data))
for _, d := range e.Data {
floats := make([]float64, len(d.Embedding))
copy(floats, d.Embedding)
emb := &embeddings.Embedding{
Vector: floats,
}
embs = append(embs, emb)
}
return embs, nil
}

// EmbeddingRequest is serialized and sent to the API server.
type EmbeddingRequest struct {
Input any `json:"input"`
Expand All @@ -84,21 +99,6 @@ type EmbeddingResponseGen[T any] struct {
Usage Usage `json:"usage"`
}

// ToEmbeddings converts the raw API response,
// parses it into a slice of embeddings and returns it.
func ToEmbeddings(e *EmbeddingResponse) ([]*embeddings.Embedding, error) {
embs := make([]*embeddings.Embedding, 0, len(e.Data))
for _, d := range e.Data {
floats := make([]float64, len(d.Embedding))
copy(floats, d.Embedding)
emb := &embeddings.Embedding{
Vector: floats,
}
embs = append(embs, emb)
}
return embs, nil
}

// toEmbeddingResp decodes the raw API response,
// parses it into a slice of embeddings and returns it.
func toEmbeddingResp[T any](resp io.Reader) (*EmbeddingResponse, error) {
Expand Down
29 changes: 15 additions & 14 deletions vertexai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,21 @@ type EmbedddingResponse struct {
Metadata map[string]any `json:"metadata"`
}

// ToEmbeddings converts the API response,
// into a slice of embeddings and returns it.
func (e *EmbedddingResponse) ToEmbeddings() ([]*embeddings.Embedding, error) {
embs := make([]*embeddings.Embedding, 0, len(e.Predictions))
for _, p := range e.Predictions {
floats := make([]float64, len(p.Embeddings.Values))
copy(floats, p.Embeddings.Values)
emb := &embeddings.Embedding{
Vector: floats,
}
embs = append(embs, emb)
}
return embs, nil
}

// Predictions is the generated response
type Predictions struct {
Embeddings struct {
Expand All @@ -54,20 +69,6 @@ type Statistics struct {
Truncated bool `json:"truncated"`
}

// ToEmbeddings converts the API response to embeddings object.
func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) {
embs := make([]*embeddings.Embedding, 0, len(e.Predictions))
for _, p := range e.Predictions {
floats := make([]float64, len(p.Embeddings.Values))
copy(floats, p.Embeddings.Values)
emb := &embeddings.Embedding{
Vector: floats,
}
embs = append(embs, emb)
}
return embs, nil
}

// Embeddings returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbedddingResponse, error) {
u, err := url.Parse(c.baseURL + "/" + c.projectID + "/" + ModelURI + "/" + c.modelID + EmbedAction)
Expand Down

0 comments on commit 3f8a97b

Please sign in to comment.