From aead741d6601d0e53f8f1c1b168b788dd57d902f Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Sat, 2 Dec 2023 16:37:49 +0000 Subject: [PATCH] feat: add support for multimodal embeddings in vertex (#10) Signed-off-by: Milos Gajdos --- vertexai/multi_embedding.go | 95 +++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 vertexai/multi_embedding.go diff --git a/vertexai/multi_embedding.go b/vertexai/multi_embedding.go new file mode 100644 index 0000000..5795fc2 --- /dev/null +++ b/vertexai/multi_embedding.go @@ -0,0 +1,95 @@ +package vertexai + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/url" + + "github.com/milosgajdos/go-embeddings/request" +) + +// MultiEmbeddingRequest is multimodal embedding request. +type MultiEmbeddingRequest struct { + Instances []MultiInstance `json:"instances"` + Params *MultiParams `json:"parameters,omitempty"` +} + +// MultiInstance contains the request payload. +type MultiInstance struct { + Text *string `json:"text,omitempty"` + Image any `json:"image,omitempty"` +} + +// ImageGCS contains GCS URI to image file. +type ImageGCS struct { + URI string `json:"gcsUri"` +} + +// ImageBase64 contains image encoded as base64 string. +type ImageBase64 struct { + Bytes string `json:"bytesBase64Encoded"` +} + +// MultiParams are additional API request parameters. +type MultiParams struct { + Dimension uint `json:"dimension"` +} + +// MultiEmbedddingResponse received from API. +type MultiEmbedddingResponse struct { + Predictions []MultiPrediction `json:"predictions"` + ModelID string `json:"deployedModelId"` +} + +// MultiPrediction for a given request. +type MultiPrediction struct { + Image []float64 `json:"imageEmbedding"` + Text []float64 `json:"textEmbedding"` +} + +// MultiEmbeddings returns multimodal embeddings for every object in EmbeddingRequest. +func (c *Client) MultiEmbeddings(ctx context.Context, embReq *MultiEmbeddingRequest) (*MultiEmbedddingResponse, error) { + u, err := url.Parse(c.baseURL + "/" + c.projectID + "/" + ModelURI + "/" + c.modelID + EmbedAction) + if err != nil { + return nil, err + } + + var body = &bytes.Buffer{} + enc := json.NewEncoder(body) + enc.SetEscapeHTML(false) + if err := enc.Encode(embReq); err != nil { + return nil, err + } + + if c.token == "" { + var err error + c.token, err = GetToken(c.tokenSrc) + if err != nil { + return nil, err + } + } + + options := []request.Option{ + request.WithBearer(c.token), + } + + req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...) + if err != nil { + return nil, err + } + + resp, err := c.doRequest(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + e := new(MultiEmbedddingResponse) + if err := json.NewDecoder(resp.Body).Decode(e); err != nil { + return nil, err + } + + return e, nil +}