Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for AWS bedrock embedding models #27

Merged
merged 1 commit into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ Here's a list of the env vars for each supported client

* `VOYAGE_API_KEY`: Voyage AI API key

### AWS Bedrock

> [!IMPORTANT]
> You must enable access to Bedrock embedding models
> See here: [https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#add-model-access](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#add-model-access)

* `AWS_REGION`: AWS region

The usual AWS environment variables read by the AWS SDKs also apply, e.g. `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`.

## nix

The project provides a simple `nix` flake that leverages [gomod2nix](https://github.com/nix-community/gomod2nix) for consistent Go environments and builds.
Expand Down
14 changes: 14 additions & 0 deletions bedrock/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package bedrock

// Model identifies a Bedrock embedding model by its model ID.
type Model string

// Model IDs of the supported Titan text embedding models.
const (
	// TitanTextV1 is the ID of the Titan text embedding model, version 1.
	TitanTextV1 Model = "amazon.titan-embed-text-v1"
	// TitanTextV2 is the ID of the Titan text embedding model, version 2.
	TitanTextV2 Model = "amazon.titan-embed-text-v2:0"
)

// String returns the model ID as a plain string.
// It makes Model satisfy the fmt.Stringer interface.
func (id Model) String() string {
	s := string(id)
	return s
}
80 changes: 80 additions & 0 deletions bedrock/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package bedrock

import (
"context"
"log"
"os"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)

const (
// DefaultRegion is the AWS region NewClient falls back to when no region
// has been configured.
DefaultRegion = "us-east-1"
)

// Client is an AWS Bedrock embeddings client.
// It carries the resolved Options it was constructed with.
type Client struct {
opts Options
}

// Options configures a Client.
type Options struct {
// Region is the AWS region used when NewClient builds its own Bedrock client.
Region string
// ModelID is the ID of the Bedrock model that Embed invokes.
ModelID string
// Client is the Bedrock runtime API client used to invoke the model.
Client *bedrockruntime.Client
}

// Option is a functional option that modifies Options.
type Option func(*Options)

// NewClient creates a new AWS Bedrock client and returns it.
//
// Region and ModelID are seeded from the AWS_REGION and
// AWS_BEDROCK_MODEL_ID environment variables, after which the given
// options are applied in order. An empty region falls back to
// DefaultRegion.
//
// If no Bedrock API client was supplied through an option, one is built
// from the default AWS configuration for the resolved region. Failing to
// load that configuration terminates the process via log.Fatal.
func NewClient(opts ...Option) *Client {
	c := &Client{
		opts: Options{
			Region:  os.Getenv("AWS_REGION"),
			ModelID: os.Getenv("AWS_BEDROCK_MODEL_ID"),
		},
	}

	for _, opt := range opts {
		opt(&c.opts)
	}

	if c.opts.Region == "" {
		c.opts.Region = DefaultRegion
	}

	// A caller-supplied Bedrock client is used as is.
	if c.opts.Client != nil {
		return c
	}

	awsCfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(c.opts.Region))
	if err != nil {
		log.Fatal(err)
	}
	c.opts.Client = bedrockruntime.NewFromConfig(awsCfg)

	return c
}

// WithRegion returns an Option that sets the AWS region.
func WithRegion(region string) Option {
	return func(opts *Options) {
		opts.Region = region
	}
}

// WithModelID returns an Option that sets the embedding model ID.
func WithModelID(id string) Option {
	return func(opts *Options) {
		opts.ModelID = id
	}
}

// WithBedrockClient returns an Option that sets the Bedrock API client.
func WithBedrockClient(client *bedrockruntime.Client) Option {
	return func(opts *Options) {
		opts.Client = client
	}
}
53 changes: 53 additions & 0 deletions bedrock/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package bedrock

import (
"context"
"testing"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/stretchr/testify/assert"
)

const (
// bedrockModelID is a placeholder model ID used by the tests.
bedrockModelID = "model"
)

// TestClient verifies that NewClient applies its defaults and that each
// functional option overrides the corresponding Options field.
//
// NOTE: the default-value assertions assume AWS_REGION and
// AWS_BEDROCK_MODEL_ID are unset in the test environment, because
// NewClient seeds its options from those variables.
func TestClient(t *testing.T) {
	t.Parallel()

	t.Run("Region", func(t *testing.T) {
		t.Parallel()
		c := NewClient()
		// testify expects (expected, actual); the order matters for the
		// labels in failure output.
		assert.Equal(t, DefaultRegion, c.opts.Region)

		testVal := "us-west-1"
		c = NewClient(WithRegion(testVal))
		assert.Equal(t, testVal, c.opts.Region)
	})

	t.Run("ModelID", func(t *testing.T) {
		t.Parallel()
		c := NewClient()
		assert.Equal(t, "", c.opts.ModelID)

		c = NewClient(WithModelID(bedrockModelID))
		assert.Equal(t, bedrockModelID, c.opts.ModelID)
	})

	t.Run("BedrockClient", func(t *testing.T) {
		t.Parallel()
		c := NewClient()
		assert.NotNil(t, c.opts.Client)
		assert.Equal(t, DefaultRegion, c.opts.Client.Options().Region)

		testVal := "us-west-1"
		cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(testVal))
		assert.NoError(t, err)
		bc := bedrockruntime.NewFromConfig(cfg)

		c = NewClient(WithBedrockClient(bc))
		assert.NotNil(t, c.opts.Client)
		assert.Equal(t, testVal, c.opts.Client.Options().Region)
	})
}
51 changes: 51 additions & 0 deletions bedrock/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package bedrock

import (
"context"
"encoding/json"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/milosgajdos/go-embeddings"
)

// Request is the embedding request. Embed JSON-encodes it and sends it as
// the body of the model invocation.
type Request struct {
// InputText is the text to embed.
InputText string `json:"inputText"`
}

// Response is the embedding response. Embed decodes it from the JSON body
// returned by the model invocation.
type Response struct {
// Embedding is the embedding vector.
Embedding []float64 `json:"embedding"`
// InputTextTokenCount is decoded from the "inputTextTokenCount" field;
// presumably the number of tokens in the input text (confirm against the
// model's API documentation).
InputTextTokenCount int `json:"inputTextTokenCount"`
}

// ToEmbeddings converts the response into a slice holding a single
// embedding. The vector is copied, so the result does not share memory
// with the response. The returned error is always nil.
func (e *Response) ToEmbeddings() ([]*embeddings.Embedding, error) {
	vec := append(make([]float64, 0, len(e.Embedding)), e.Embedding...)
	emb := &embeddings.Embedding{Vector: vec}
	return []*embeddings.Embedding{emb}, nil
}

// Embed returns the embedding for the text in embReq.
//
// The request is JSON-encoded and sent to the client's configured model
// through the Bedrock InvokeModel call; the JSON response body is decoded
// and converted with Response.ToEmbeddings.
//
// An error is returned if the request cannot be encoded, the model
// invocation fails, or the response body cannot be decoded.
func (c *Client) Embed(ctx context.Context, embReq *Request) ([]*embeddings.Embedding, error) {
	payload, err := json.Marshal(embReq)
	if err != nil {
		return nil, err
	}

	resp, err := c.opts.Client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
		Body:        payload,
		ModelId:     aws.String(c.opts.ModelID),
		ContentType: aws.String("application/json"),
	})
	if err != nil {
		return nil, err
	}

	var embs Response
	if err := json.Unmarshal(resp.Body, &embs); err != nil {
		// Surface the decode failure; returning (nil, nil) here would make
		// a malformed response indistinguishable from success.
		return nil, err
	}

	return embs.ToEmbeddings()
}
37 changes: 37 additions & 0 deletions cmd/bedrock/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/bedrock"
)

var (
// input is the text to embed; set by the -input flag.
input string
// model is the Bedrock model ID to use; set by the -model flag.
model string
)

func init() {
flag.StringVar(&input, "input", "what is life", "input data")
flag.StringVar(&model, "model", bedrock.TitanTextV1.String(), "model name")
}

// main parses the flags, requests an embedding of the input text from the
// selected Bedrock model and prints how many embeddings came back.
// Any embedding error terminates the program via log.Fatal.
func main() {
	flag.Parse()

	client := bedrock.NewClient(bedrock.WithModelID(model))

	vectors, err := client.Embed(context.Background(), &bedrock.Request{InputText: input})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("got %d embeddings", len(vectors))
}
2 changes: 1 addition & 1 deletion cohere/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Options struct {
HTTPClient *client.HTTP
}

// Option is functional graph option.
// Option is functional option.
type Option func(*Options)

// NewClient creates a new HTTP API client and returns it.
Expand Down
15 changes: 15 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,28 @@ module github.com/milosgajdos/go-embeddings
go 1.20

require (
github.com/aws/aws-sdk-go-v2 v1.27.0
github.com/aws/aws-sdk-go-v2/config v1.27.15
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3
github.com/stretchr/testify v1.8.4
golang.org/x/oauth2 v0.15.0
)

require (
cloud.google.com/go/compute v1.20.1 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
30 changes: 30 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,36 @@ cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZN
cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/config v1.27.15 h1:uNnGLZ+DutuNEkuPh6fwqK7LpEiPmzb7MIMA1mNWEUc=
github.com/aws/aws-sdk-go-v2/config v1.27.15/go.mod h1:7j7Kxx9/7kTmL7z4LlhwQe63MYEE5vkVV6nWg4ZAI8M=
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 h1:dQLK4TjtnlRGb0czOht2CevZ5l6RSyRWAnKeGd7VAFE=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3/go.mod h1:TL79f2P6+8Q7dTsILpiVST+AL9lkF6PPGI167Ny0Cjw=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 h1:Wx0rlZoEJR7JwlSZcHnEa7CNjrSIyVxMFWGAaXy4fJY=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9/go.mod h1:aVMHdE0aHO3v+f/iw01fmXV/5DbfQ3Bi9nN7nd9bE9Y=
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8 h1:Kv1hwNG6jHC/sxMTe5saMjH6t6ZLkgfvVxyEjfWL1ks=
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8/go.mod h1:c1qtZUWtygI6ZdvKppzCSXsDOq5I4luJPZ0Ud3juFCA=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2 h1:nWBZ1xHCF+A7vv9sDzJOq4NWIdzFYm0kH7Pr4OjHYsQ=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2/go.mod h1:9lmoVDVLz/yUZwLaQ676TK02fhCu4+PgRSmMaKR1ozk=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9 h1:Qp6Boy0cGDloOE3zI6XhNLNZgjNS8YmiFQFHe71SaW0=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9/go.mod h1:0Aqn1MnEuitqfsCNyKsdKLhDUOr4txD/g19EfiUqgws=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand Down
Loading
Loading