Skip to content

Commit a501b98

Browse files
authored
feat: add support for AWS bedrock embedding models (#27)
We've also fixed some tests and godoc comments Signed-off-by: Milos Gajdos <[email protected]>
1 parent b1ee5f2 commit a501b98

File tree

15 files changed

+349
-14
lines changed

15 files changed

+349
-14
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ Here's a list of the env vars for each supported client
4444

4545
* `VOYAGE_API_KEY`: Voyage AI API key
4646

47+
### AWS Bedrock
48+
49+
> [!IMPORTANT]
50+
> You must enable access to Bedrock embedding models
51+
> See here: [https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#add-model-access](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#add-model-access)
52+
53+
* `AWS_REGION`: AWS region
54+
55+
Usual AWS env vars as read by the AWS SDKs i.e. `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, etc.
56+
4757
## nix
4858

4959
The project provides a simple `nix` flake tha leverages [gomod2nix](https://github.com/nix-community/gomod2nix) for consistent Go environments and builds.

bedrock/bedrock.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package bedrock
2+
3+
// Model is embedding model.
4+
type Model string
5+
6+
const (
7+
TitanTextV1 Model = "amazon.titan-embed-text-v1"
8+
TitanTextV2 Model = "amazon.titan-embed-text-v2:0"
9+
)
10+
11+
// String implements stringer.
12+
func (m Model) String() string {
13+
return string(m)
14+
}

bedrock/client.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package bedrock
2+
3+
import (
4+
"context"
5+
"log"
6+
"os"
7+
8+
"github.com/aws/aws-sdk-go-v2/config"
9+
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
10+
)
11+
12+
const (
13+
// DefaultRegion is default AWS region
14+
DefaultRegion = "us-east-1"
15+
)
16+
17+
type Client struct {
18+
opts Options
19+
}
20+
21+
type Options struct {
22+
Region string
23+
ModelID string
24+
Client *bedrockruntime.Client
25+
}
26+
27+
// Option is functional option.
28+
type Option func(*Options)
29+
30+
// NewClient creates a new AWS Bedrock HTTP API client and returns it.
31+
// By default it reads the default AWS evnironment variables.
32+
// and constructs the AWS API client.
33+
func NewClient(opts ...Option) *Client {
34+
options := Options{
35+
Region: os.Getenv("AWS_REGION"),
36+
ModelID: os.Getenv("AWS_BEDROCK_MODEL_ID"),
37+
}
38+
39+
for _, apply := range opts {
40+
apply(&options)
41+
}
42+
43+
if options.Region == "" {
44+
options.Region = DefaultRegion
45+
}
46+
47+
if options.Client == nil {
48+
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(options.Region))
49+
if err != nil {
50+
log.Fatal(err)
51+
}
52+
53+
options.Client = bedrockruntime.NewFromConfig(cfg)
54+
}
55+
56+
return &Client{
57+
opts: options,
58+
}
59+
}
60+
61+
// WithRegion sets AWS region.
62+
func WithRegion(region string) Option {
63+
return func(o *Options) {
64+
o.Region = region
65+
}
66+
}
67+
68+
// WithModelID sets the Tital embedding model ID>
69+
func WithModelID(id string) Option {
70+
return func(o *Options) {
71+
o.ModelID = id
72+
}
73+
}
74+
75+
// WithBedrockClient sets Bedrock API client.
76+
func WithBedrockClient(client *bedrockruntime.Client) Option {
77+
return func(o *Options) {
78+
o.Client = client
79+
}
80+
}

bedrock/client_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package bedrock
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/aws/aws-sdk-go-v2/config"
8+
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
const (
13+
bedrockModelID = "model"
14+
)
15+
16+
func TestClient(t *testing.T) {
17+
t.Parallel()
18+
19+
t.Run("Region", func(t *testing.T) {
20+
t.Parallel()
21+
c := NewClient()
22+
assert.Equal(t, c.opts.Region, DefaultRegion)
23+
24+
testVal := "us-west-1"
25+
c = NewClient(WithRegion(testVal))
26+
assert.Equal(t, c.opts.Region, testVal)
27+
})
28+
29+
t.Run("ModelID", func(t *testing.T) {
30+
t.Parallel()
31+
c := NewClient()
32+
assert.Equal(t, c.opts.ModelID, "")
33+
34+
c = NewClient(WithModelID(bedrockModelID))
35+
assert.Equal(t, c.opts.ModelID, bedrockModelID)
36+
})
37+
38+
t.Run("BedrockClient", func(t *testing.T) {
39+
t.Parallel()
40+
c := NewClient()
41+
assert.NotNil(t, c.opts.Client)
42+
assert.Equal(t, c.opts.Client.Options().Region, DefaultRegion)
43+
44+
testVal := "us-west-1"
45+
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(testVal))
46+
assert.NoError(t, err)
47+
bc := bedrockruntime.NewFromConfig(cfg)
48+
49+
c = NewClient(WithBedrockClient(bc))
50+
assert.NotNil(t, c.opts.Client)
51+
assert.Equal(t, c.opts.Client.Options().Region, testVal)
52+
})
53+
}

bedrock/embedding.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package bedrock
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
"github.com/aws/aws-sdk-go-v2/aws"
8+
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
9+
"github.com/milosgajdos/go-embeddings"
10+
)
11+
12+
type Request struct {
13+
InputText string `json:"inputText"`
14+
}
15+
16+
type Response struct {
17+
Embedding []float64 `json:"embedding"`
18+
InputTextTokenCount int `json:"inputTextTokenCount"`
19+
}
20+
21+
func (e *Response) ToEmbeddings() ([]*embeddings.Embedding, error) {
22+
vals := make([]float64, len(e.Embedding))
23+
copy(vals, e.Embedding)
24+
return []*embeddings.Embedding{
25+
{Vector: vals},
26+
}, nil
27+
}
28+
29+
// Embed returns embeddings for every object in EmbeddingRequest.
30+
func (c *Client) Embed(ctx context.Context, embReq *Request) ([]*embeddings.Embedding, error) {
31+
payload, err := json.Marshal(embReq)
32+
if err != nil {
33+
return nil, err
34+
}
35+
36+
resp, err := c.opts.Client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
37+
Body: payload,
38+
ModelId: aws.String(c.opts.ModelID),
39+
ContentType: aws.String("application/json"),
40+
})
41+
if err != nil {
42+
return nil, err
43+
}
44+
45+
var embs Response
46+
if err = json.Unmarshal(resp.Body, &embs); err != nil {
47+
return nil, nil
48+
}
49+
50+
return embs.ToEmbeddings()
51+
}

cmd/bedrock/main.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
"log"
8+
9+
"github.com/milosgajdos/go-embeddings/bedrock"
10+
)
11+
12+
var (
13+
input string
14+
model string
15+
)
16+
17+
func init() {
18+
flag.StringVar(&input, "input", "what is life", "input data")
19+
flag.StringVar(&model, "model", bedrock.TitanTextV1.String(), "model name")
20+
}
21+
22+
func main() {
23+
flag.Parse()
24+
25+
c := bedrock.NewClient(bedrock.WithModelID(model))
26+
27+
embReq := &bedrock.Request{
28+
InputText: input,
29+
}
30+
31+
embs, err := c.Embed(context.Background(), embReq)
32+
if err != nil {
33+
log.Fatal(err)
34+
}
35+
36+
fmt.Printf("got %d embeddings", len(embs))
37+
}

cohere/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ type Options struct {
2727
HTTPClient *client.HTTP
2828
}
2929

30-
// Option is functional graph option.
30+
// Option is functional option.
3131
type Option func(*Options)
3232

3333
// NewClient creates a new HTTP API client and returns it.

go.mod

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,28 @@ module github.com/milosgajdos/go-embeddings
33
go 1.20
44

55
require (
6+
github.com/aws/aws-sdk-go-v2 v1.27.0
7+
github.com/aws/aws-sdk-go-v2/config v1.27.15
8+
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3
69
github.com/stretchr/testify v1.8.4
710
golang.org/x/oauth2 v0.15.0
811
)
912

1013
require (
1114
cloud.google.com/go/compute v1.20.1 // indirect
1215
cloud.google.com/go/compute/metadata v0.2.3 // indirect
16+
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
17+
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 // indirect
18+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 // indirect
19+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect
20+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect
21+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
22+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect
23+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 // indirect
24+
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8 // indirect
25+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2 // indirect
26+
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9 // indirect
27+
github.com/aws/smithy-go v1.20.2 // indirect
1328
github.com/davecgh/go-spew v1.1.1 // indirect
1429
github.com/golang/protobuf v1.5.3 // indirect
1530
github.com/pmezard/go-difflib v1.0.0 // indirect

go.sum

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,36 @@ cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZN
22
cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
33
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
44
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
5+
github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo=
6+
github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
7+
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
8+
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
9+
github.com/aws/aws-sdk-go-v2/config v1.27.15 h1:uNnGLZ+DutuNEkuPh6fwqK7LpEiPmzb7MIMA1mNWEUc=
10+
github.com/aws/aws-sdk-go-v2/config v1.27.15/go.mod h1:7j7Kxx9/7kTmL7z4LlhwQe63MYEE5vkVV6nWg4ZAI8M=
11+
github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo=
12+
github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU=
13+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 h1:dQLK4TjtnlRGb0czOht2CevZ5l6RSyRWAnKeGd7VAFE=
14+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3/go.mod h1:TL79f2P6+8Q7dTsILpiVST+AL9lkF6PPGI167Ny0Cjw=
15+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk=
16+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI=
17+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g=
18+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI=
19+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
20+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
21+
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY=
22+
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ=
23+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs=
24+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg=
25+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 h1:Wx0rlZoEJR7JwlSZcHnEa7CNjrSIyVxMFWGAaXy4fJY=
26+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9/go.mod h1:aVMHdE0aHO3v+f/iw01fmXV/5DbfQ3Bi9nN7nd9bE9Y=
27+
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8 h1:Kv1hwNG6jHC/sxMTe5saMjH6t6ZLkgfvVxyEjfWL1ks=
28+
github.com/aws/aws-sdk-go-v2/service/sso v1.20.8/go.mod h1:c1qtZUWtygI6ZdvKppzCSXsDOq5I4luJPZ0Ud3juFCA=
29+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2 h1:nWBZ1xHCF+A7vv9sDzJOq4NWIdzFYm0kH7Pr4OjHYsQ=
30+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.2/go.mod h1:9lmoVDVLz/yUZwLaQ676TK02fhCu4+PgRSmMaKR1ozk=
31+
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9 h1:Qp6Boy0cGDloOE3zI6XhNLNZgjNS8YmiFQFHe71SaW0=
32+
github.com/aws/aws-sdk-go-v2/service/sts v1.28.9/go.mod h1:0Aqn1MnEuitqfsCNyKsdKLhDUOr4txD/g19EfiUqgws=
33+
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
34+
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
535
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
636
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
737
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=

0 commit comments

Comments
 (0)