Skip to content

Commit

Permalink
feat: add support for AWS bedrock embedding models (#27)
Browse files Browse the repository at this point in the history
We've also fixed some tests and godoc comments

Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos committed May 18, 2024
1 parent b1ee5f2 commit a501b98
Show file tree
Hide file tree
Showing 15 changed files with 349 additions and 14 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ Here's a list of the env vars for each supported client

* `VOYAGE_API_KEY`: Voyage AI API key

### AWS Bedrock

> [!IMPORTANT]
> You must enable access to Bedrock embedding models
> See here: [https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#add-model-access](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#add-model-access)
* `AWS_REGION`: AWS region

Usual AWS env vars as read by the AWS SDKs i.e. `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, etc.

## nix

The project provides a simple `nix` flake tha leverages [gomod2nix](https://github.com/nix-community/gomod2nix) for consistent Go environments and builds.
Expand Down
14 changes: 14 additions & 0 deletions bedrock/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package bedrock

// Model is embedding model.
type Model string

const (
TitanTextV1 Model = "amazon.titan-embed-text-v1"
TitanTextV2 Model = "amazon.titan-embed-text-v2:0"
)

// String implements stringer.
func (m Model) String() string {
return string(m)
}
80 changes: 80 additions & 0 deletions bedrock/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package bedrock

import (
"context"
"log"
"os"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
)

const (
// DefaultRegion is default AWS region
DefaultRegion = "us-east-1"
)

type Client struct {
opts Options
}

type Options struct {
Region string
ModelID string
Client *bedrockruntime.Client
}

// Option is functional option.
type Option func(*Options)

// NewClient creates a new AWS Bedrock HTTP API client and returns it.
// By default it reads the default AWS evnironment variables.
// and constructs the AWS API client.
func NewClient(opts ...Option) *Client {
options := Options{
Region: os.Getenv("AWS_REGION"),
ModelID: os.Getenv("AWS_BEDROCK_MODEL_ID"),
}

for _, apply := range opts {
apply(&options)
}

if options.Region == "" {
options.Region = DefaultRegion
}

if options.Client == nil {
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(options.Region))
if err != nil {
log.Fatal(err)
}

options.Client = bedrockruntime.NewFromConfig(cfg)
}

return &Client{
opts: options,
}
}

// WithRegion sets AWS region.
func WithRegion(region string) Option {
return func(o *Options) {
o.Region = region
}
}

// WithModelID sets the Tital embedding model ID>
func WithModelID(id string) Option {
return func(o *Options) {
o.ModelID = id
}
}

// WithBedrockClient sets Bedrock API client.
func WithBedrockClient(client *bedrockruntime.Client) Option {
return func(o *Options) {
o.Client = client
}
}
53 changes: 53 additions & 0 deletions bedrock/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package bedrock

import (
"context"
"testing"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/stretchr/testify/assert"
)

const (
bedrockModelID = "model"
)

func TestClient(t *testing.T) {
t.Parallel()

t.Run("Region", func(t *testing.T) {
t.Parallel()
c := NewClient()
assert.Equal(t, c.opts.Region, DefaultRegion)

testVal := "us-west-1"
c = NewClient(WithRegion(testVal))
assert.Equal(t, c.opts.Region, testVal)
})

t.Run("ModelID", func(t *testing.T) {
t.Parallel()
c := NewClient()
assert.Equal(t, c.opts.ModelID, "")

c = NewClient(WithModelID(bedrockModelID))
assert.Equal(t, c.opts.ModelID, bedrockModelID)
})

t.Run("BedrockClient", func(t *testing.T) {
t.Parallel()
c := NewClient()
assert.NotNil(t, c.opts.Client)
assert.Equal(t, c.opts.Client.Options().Region, DefaultRegion)

testVal := "us-west-1"
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(testVal))
assert.NoError(t, err)
bc := bedrockruntime.NewFromConfig(cfg)

c = NewClient(WithBedrockClient(bc))
assert.NotNil(t, c.opts.Client)
assert.Equal(t, c.opts.Client.Options().Region, testVal)
})
}
51 changes: 51 additions & 0 deletions bedrock/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package bedrock

import (
"context"
"encoding/json"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/milosgajdos/go-embeddings"
)

type Request struct {
InputText string `json:"inputText"`
}

type Response struct {
Embedding []float64 `json:"embedding"`
InputTextTokenCount int `json:"inputTextTokenCount"`
}

func (e *Response) ToEmbeddings() ([]*embeddings.Embedding, error) {
vals := make([]float64, len(e.Embedding))
copy(vals, e.Embedding)
return []*embeddings.Embedding{
{Vector: vals},
}, nil
}

// Embed returns embeddings for every object in EmbeddingRequest.
func (c *Client) Embed(ctx context.Context, embReq *Request) ([]*embeddings.Embedding, error) {
payload, err := json.Marshal(embReq)
if err != nil {
return nil, err
}

resp, err := c.opts.Client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
Body: payload,
ModelId: aws.String(c.opts.ModelID),
ContentType: aws.String("application/json"),
})
if err != nil {
return nil, err
}

var embs Response
if err = json.Unmarshal(resp.Body, &embs); err != nil {
return nil, nil
}

return embs.ToEmbeddings()
}
37 changes: 37 additions & 0 deletions cmd/bedrock/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"context"
"flag"
"fmt"
"log"

"github.com/milosgajdos/go-embeddings/bedrock"
)

var (
input string
model string
)

func init() {
flag.StringVar(&input, "input", "what is life", "input data")
flag.StringVar(&model, "model", bedrock.TitanTextV1.String(), "model name")
}

func main() {
flag.Parse()

c := bedrock.NewClient(bedrock.WithModelID(model))

embReq := &bedrock.Request{
InputText: input,
}

embs, err := c.Embed(context.Background(), embReq)
if err != nil {
log.Fatal(err)
}

fmt.Printf("got %d embeddings", len(embs))
}
2 changes: 1 addition & 1 deletion cohere/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Options struct {
HTTPClient *client.HTTP
}

// Option is functional graph option.
// Option is functional option.
type Option func(*Options)

// NewClient creates a new HTTP API client and returns it.
Expand Down
15 changes: 15 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,28 @@ module github.com/milosgajdos/go-embeddings
go 1.20

require (
github.com/aws/aws-sdk-go-v2 v1.27.0
github.com/aws/aws-sdk-go-v2/config v1.27.15
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3
github.com/stretchr/testify v1.8.4
golang.org/x/oauth2 v0.15.0
)

require (
cloud.google.com/go/compute v1.20.1 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
30 changes: 30 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,36 @@ cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZN
cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/config v1.27.15 h1:uNnGLZ+DutuNEkuPh6fwqK7LpEiPmzb7MIMA1mNWEUc=
github.com/aws/aws-sdk-go-v2/config v1.27.15/go.mod h1:7j7Kxx9/7kTmL7z4LlhwQe63MYEE5vkVV6nWg4ZAI8M=
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 h1:dQLK4TjtnlRGb0czOht2CevZ5l6RSyRWAnKeGd7VAFE=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3/go.mod h1:TL79f2P6+8Q7dTsILpiVST+AL9lkF6PPGI167Ny0Cjw=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 h1:Wx0rlZoEJR7JwlSZcHnEa7CNjrSIyVxMFWGAaXy4fJY=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9/go.mod h1:aVMHdE0aHO3v+f/iw01fmXV/5DbfQ3Bi9nN7nd9bE9Y=
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8 h1:Kv1hwNG6jHC/sxMTe5saMjH6t6ZLkgfvVxyEjfWL1ks=
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8/go.mod h1:c1qtZUWtygI6ZdvKppzCSXsDOq5I4luJPZ0Ud3juFCA=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2 h1:nWBZ1xHCF+A7vv9sDzJOq4NWIdzFYm0kH7Pr4OjHYsQ=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2/go.mod h1:9lmoVDVLz/yUZwLaQ676TK02fhCu4+PgRSmMaKR1ozk=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9 h1:Qp6Boy0cGDloOE3zI6XhNLNZgjNS8YmiFQFHe71SaW0=
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9/go.mod h1:0Aqn1MnEuitqfsCNyKsdKLhDUOr4txD/g19EfiUqgws=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand Down
Loading

0 comments on commit a501b98

Please sign in to comment.