generated from milosgajdos/go-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support for AWS bedrock embedding models (#27)
We've also fixed some tests and godoc comments Signed-off-by: Milos Gajdos <[email protected]>
- Loading branch information
1 parent
b1ee5f2
commit a501b98
Showing
15 changed files
with
349 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package bedrock | ||
|
||
// Model is embedding model. | ||
type Model string | ||
|
||
const ( | ||
TitanTextV1 Model = "amazon.titan-embed-text-v1" | ||
TitanTextV2 Model = "amazon.titan-embed-text-v2:0" | ||
) | ||
|
||
// String implements stringer. | ||
func (m Model) String() string { | ||
return string(m) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package bedrock | ||
|
||
import ( | ||
"context" | ||
"log" | ||
"os" | ||
|
||
"github.com/aws/aws-sdk-go-v2/config" | ||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
) | ||
|
||
const ( | ||
// DefaultRegion is default AWS region | ||
DefaultRegion = "us-east-1" | ||
) | ||
|
||
type Client struct { | ||
opts Options | ||
} | ||
|
||
type Options struct { | ||
Region string | ||
ModelID string | ||
Client *bedrockruntime.Client | ||
} | ||
|
||
// Option is functional option. | ||
type Option func(*Options) | ||
|
||
// NewClient creates a new AWS Bedrock HTTP API client and returns it. | ||
// By default it reads the default AWS evnironment variables. | ||
// and constructs the AWS API client. | ||
func NewClient(opts ...Option) *Client { | ||
options := Options{ | ||
Region: os.Getenv("AWS_REGION"), | ||
ModelID: os.Getenv("AWS_BEDROCK_MODEL_ID"), | ||
} | ||
|
||
for _, apply := range opts { | ||
apply(&options) | ||
} | ||
|
||
if options.Region == "" { | ||
options.Region = DefaultRegion | ||
} | ||
|
||
if options.Client == nil { | ||
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(options.Region)) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
options.Client = bedrockruntime.NewFromConfig(cfg) | ||
} | ||
|
||
return &Client{ | ||
opts: options, | ||
} | ||
} | ||
|
||
// WithRegion sets AWS region. | ||
func WithRegion(region string) Option { | ||
return func(o *Options) { | ||
o.Region = region | ||
} | ||
} | ||
|
||
// WithModelID sets the Tital embedding model ID> | ||
func WithModelID(id string) Option { | ||
return func(o *Options) { | ||
o.ModelID = id | ||
} | ||
} | ||
|
||
// WithBedrockClient sets Bedrock API client. | ||
func WithBedrockClient(client *bedrockruntime.Client) Option { | ||
return func(o *Options) { | ||
o.Client = client | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package bedrock | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/aws/aws-sdk-go-v2/config" | ||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
const ( | ||
bedrockModelID = "model" | ||
) | ||
|
||
func TestClient(t *testing.T) { | ||
t.Parallel() | ||
|
||
t.Run("Region", func(t *testing.T) { | ||
t.Parallel() | ||
c := NewClient() | ||
assert.Equal(t, c.opts.Region, DefaultRegion) | ||
|
||
testVal := "us-west-1" | ||
c = NewClient(WithRegion(testVal)) | ||
assert.Equal(t, c.opts.Region, testVal) | ||
}) | ||
|
||
t.Run("ModelID", func(t *testing.T) { | ||
t.Parallel() | ||
c := NewClient() | ||
assert.Equal(t, c.opts.ModelID, "") | ||
|
||
c = NewClient(WithModelID(bedrockModelID)) | ||
assert.Equal(t, c.opts.ModelID, bedrockModelID) | ||
}) | ||
|
||
t.Run("BedrockClient", func(t *testing.T) { | ||
t.Parallel() | ||
c := NewClient() | ||
assert.NotNil(t, c.opts.Client) | ||
assert.Equal(t, c.opts.Client.Options().Region, DefaultRegion) | ||
|
||
testVal := "us-west-1" | ||
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(testVal)) | ||
assert.NoError(t, err) | ||
bc := bedrockruntime.NewFromConfig(cfg) | ||
|
||
c = NewClient(WithBedrockClient(bc)) | ||
assert.NotNil(t, c.opts.Client) | ||
assert.Equal(t, c.opts.Client.Options().Region, testVal) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package bedrock | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
"github.com/milosgajdos/go-embeddings" | ||
) | ||
|
||
type Request struct { | ||
InputText string `json:"inputText"` | ||
} | ||
|
||
type Response struct { | ||
Embedding []float64 `json:"embedding"` | ||
InputTextTokenCount int `json:"inputTextTokenCount"` | ||
} | ||
|
||
func (e *Response) ToEmbeddings() ([]*embeddings.Embedding, error) { | ||
vals := make([]float64, len(e.Embedding)) | ||
copy(vals, e.Embedding) | ||
return []*embeddings.Embedding{ | ||
{Vector: vals}, | ||
}, nil | ||
} | ||
|
||
// Embed returns embeddings for every object in EmbeddingRequest. | ||
func (c *Client) Embed(ctx context.Context, embReq *Request) ([]*embeddings.Embedding, error) { | ||
payload, err := json.Marshal(embReq) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
resp, err := c.opts.Client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ | ||
Body: payload, | ||
ModelId: aws.String(c.opts.ModelID), | ||
ContentType: aws.String("application/json"), | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var embs Response | ||
if err = json.Unmarshal(resp.Body, &embs); err != nil { | ||
return nil, nil | ||
} | ||
|
||
return embs.ToEmbeddings() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"flag" | ||
"fmt" | ||
"log" | ||
|
||
"github.com/milosgajdos/go-embeddings/bedrock" | ||
) | ||
|
||
var ( | ||
input string | ||
model string | ||
) | ||
|
||
func init() { | ||
flag.StringVar(&input, "input", "what is life", "input data") | ||
flag.StringVar(&model, "model", bedrock.TitanTextV1.String(), "model name") | ||
} | ||
|
||
func main() { | ||
flag.Parse() | ||
|
||
c := bedrock.NewClient(bedrock.WithModelID(model)) | ||
|
||
embReq := &bedrock.Request{ | ||
InputText: input, | ||
} | ||
|
||
embs, err := c.Embed(context.Background(), embReq) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
fmt.Printf("got %d embeddings", len(embs)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.