diff --git a/example-spec.yaml b/example-spec.yaml index 4bd3a55..11c8424 100644 --- a/example-spec.yaml +++ b/example-spec.yaml @@ -56,9 +56,11 @@ role: generate # the starenv derefers, lambdafy adds the following derefers: # # - lambdafy_sqs_send: This derefer will be replaced with a URL which when POSTed -# to will send a message to the SQS queue whose ARN is specified. The body -# of the POST will be sent as the SQS message body. If header -# 'Lambdafy-SQS-Group-Id' is set, it will be used as Group ID for the +# to will send a message to the SQS queue whose ARN is specified. This accepts +# either a JSON array of messages or a single message. If an array, the body +# of the POST will be split into batches and sent as entries in SQS send message batch. +# Otherwise, if a single messsage, the body of the POST will be sent as the SQS message body. +# If header 'Lambdafy-SQS-Group-Id' is set, it will be used as Group ID for the # message. A 2xx/3xx response is considered a success, otherwise a fail. See # the example below for usage. # Note: The necessary IAM role permissions to send SQS messages are added diff --git a/proxy/sqs.go b/proxy/sqs.go index b7b5c01..6f87ec5 100644 --- a/proxy/sqs.go +++ b/proxy/sqs.go @@ -3,23 +3,29 @@ package main import ( "context" "encoding/hex" + "encoding/json" "fmt" "io" "io/ioutil" "log" "math/rand" + "mime" "net/http" "net/url" "regexp" "strconv" "strings" + "time" "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" sqs "github.com/aws/aws-sdk-go-v2/service/sqs" + sqstypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" ) +const maxSQSBatchSize = 10 // SQS allows a maximum of 10 messages per batch + var sqsARNPat = regexp.MustCompile(`^arn:aws:sqs:([^:]+):([^:]+):(.+)$`) // getSQSQueueURL returns the URL of the SQS queue given its ARN. @@ -128,10 +134,13 @@ func (d sqsSendDerefer) Deref(arn string) (string, error) { var sqsIDToQueueURL = sqsSendDerefer{} const sqsGroupIDHeader = "Lambdafy-SQS-Group-Id" +const batchMessageHeader = "Lambdafy-SQS-Batch-Message" // handleSQSSend handles HTTP POST requests and translates them to SQS send // message. // Lambdafy-SQS-Group-Id header is used to set the message group ID. +// Lambdafy-SQS-Batch-Message header is used to indicate that the request body +// contains a JSON array of messages to be sent in a batch. func handleSQSSend(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -168,18 +177,122 @@ func handleSQSSend(w http.ResponseWriter, r *http.Request) { } sqsCl := sqs.NewFromConfig(c) - if _, err := sqsCl.SendMessage(context.Background(), &sqs.SendMessageInput{ - MessageBody: aws.String(string(body)), - QueueUrl: aws.String(qURL), - MessageGroupId: groupID, - }); err != nil { - log.Printf("error sending SQS message: %v", err) - http.Error(w, fmt.Sprintf("Error sending SQS message: %v", err), http.StatusInternalServerError) + isBatchMessage := r.Header.Get(batchMessageHeader) != "" + // Single message - use regular send + if !isBatchMessage { + if _, err := sqsCl.SendMessage(context.Background(), &sqs.SendMessageInput{ + MessageBody: aws.String(string(body)), + QueueUrl: aws.String(qURL), + MessageGroupId: groupID, + }); err != nil { + log.Printf("error sending SQS message: %v", err) + http.Error(w, fmt.Sprintf("Error sending SQS message: %v", err), http.StatusInternalServerError) + return + } + + log.Printf("sent an SQS message to '%s'", qURL) + return + } + + // Batch send message - expect the correct Content-Type and + // a JSON array of string messages in the request body + + // Check if the Content-Type media type is application/json + // instead of direct string equality check, as it may contain additional parameters. + contentType := r.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + + if err != nil { + log.Printf("error parsing Content-Type header: %v", err) + http.Error(w, fmt.Sprintf("Error parsing Content-Type header: %v", err), http.StatusBadRequest) + return + } + if mediaType != "application/json" { + http.Error(w, "Content-Type must be application/json for batch messages", http.StatusBadRequest) + return + } + + var messages []string + if err := json.Unmarshal(body, &messages); err != nil { + log.Printf("Send message batch failure - Invalid JSON array: %v", err) + http.Error(w, "Invalid JSON array", http.StatusBadRequest) return } - log.Printf("sent an SQS message to '%s'", qURL) + if len(messages) == 0 { + log.Printf("Send message batch failure - Empty message array") + http.Error(w, "Empty message array", http.StatusBadRequest) + return + } + if len(messages) > maxSQSBatchSize { + log.Printf("Send message batch failure - Too many messages in batch, maximum is %d", maxSQSBatchSize) + http.Error(w, fmt.Sprintf("Too many messages in batch, maximum is %d", maxSQSBatchSize), http.StatusBadRequest) + return + } + + entries := make([]sqstypes.SendMessageBatchRequestEntry, len(messages)) + for j, msg := range messages { + entries[j] = sqstypes.SendMessageBatchRequestEntry{ + Id: aws.String(fmt.Sprintf("%d", j)), + MessageBody: aws.String(msg), + MessageGroupId: groupID, + } + } + + var attempts int = 0 + var retryable_entries []sqstypes.SendMessageBatchRequestEntry = entries + var nonRetryableEntries []sqstypes.SendMessageBatchRequestEntry = nil + + for (attempts == 0 || len(retryable_entries) > 0) && attempts < 5 { + // Sleep for exponential backoff on retry + if attempts > 0 { + // bit shift to calculate the sleep duration -> 500ms, 1s, 2s, 4s, 8s + sleepDuration := (1 << attempts) * 500 // Exponential backoff in milliseconds + time.Sleep(time.Duration(sleepDuration) * time.Millisecond) + } + + attempts++ + output, err := sqsCl.SendMessageBatch(context.Background(), &sqs.SendMessageBatchInput{ + QueueUrl: aws.String(qURL), + Entries: retryable_entries, + }) + if err != nil { + log.Printf("error sending SQS message batch: %v", err) + http.Error(w, fmt.Sprintf("Error sending SQS message batch: %v", err), http.StatusInternalServerError) + return + } + retryable_entries = nil // Reset retryable entries for the next attempt + if len(output.Failed) > 0 { + log.Printf("failed to send %d SQS messages in batch", len(output.Failed)) + for _, f := range output.Failed { + fmt.Printf( + "failed to send SQS message %s: %s (SenderFault: %t, Code: %s)\n", + *f.Id, *f.Message, f.SenderFault, *f.Code, + ) + id, err := strconv.Atoi(*f.Id) + if err != nil { + log.Printf("error parsing SQS message ID '%s': %v", *f.Id, err) + http.Error(w, fmt.Sprintf("Error parsing SQS message ID '%s': %v", *f.Id, err), http.StatusInternalServerError) + return + } + if f.SenderFault { + // Non-retryable error + nonRetryableEntries = append(nonRetryableEntries, entries[id]) + } else { + // Retryable error + retryable_entries = append(retryable_entries, entries[id]) + } + } + } + } + + if len(retryable_entries)+len(nonRetryableEntries) > 0 { + log.Printf("%d of %d SQS messages in batch failed", len(retryable_entries)+len(nonRetryableEntries), len(entries)) + http.Error(w, fmt.Sprintf("%d of %d SQS messages in batch failed", len(retryable_entries)+len(nonRetryableEntries), len(entries)), http.StatusInternalServerError) + return + } + log.Printf("sent %d SQS messages to '%s'", len(messages), qURL) } const sendSQSStarenvTag = "lambdafy_sqs_send"