diff --git a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
index bdf8f034..b129f6ac 100644
--- a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
+++ b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
@@ -67,8 +67,8 @@ extension Lambda {
 ///
 /// It accepts three types of requests from the Lambda function (through the LambdaRuntimeClient):
 /// 1. GET /next - the lambda function polls this endpoint to get the next invocation request
-/// 2. POST /:requestID/response - the lambda function posts the response to the invocation request
-/// 3. POST /:requestID/error - the lambda function posts an error response to the invocation request
+/// 2. POST /:requestId/response - the lambda function posts the response to the invocation request
+/// 3. POST /:requestId/error - the lambda function posts an error response to the invocation request
 ///
 /// It also accepts one type of request from the client invoking the lambda function:
 /// 1. POST /invoke - the client posts the event to the lambda function
@@ -235,6 +235,7 @@ internal struct LambdaHTTPServer {
 
             var requestHead: HTTPRequestHead!
             var requestBody: ByteBuffer?
+            var requestId: String?
 
             // Note that this method is non-throwing and we are catching any error.
             // We do this since we don't want to tear down the whole server when a single connection
@@ -246,27 +247,53 @@ internal struct LambdaHTTPServer {
                 switch inboundData {
                 case .head(let head):
                     requestHead = head
+                    requestId = getRequestId(from: requestHead)
+
+                    // for streaming requests, push a partial head response
+                    if self.isStreamingResponse(requestHead) {
+                        await self.responsePool.push(
+                            LocalServerResponse(
+                                id: requestId,
+                                status: .ok
+                            )
+                        )
+                    }
 
                 case .body(let body):
-                    requestBody.setOrWriteImmutableBuffer(body)
+                    precondition(requestHead != nil, "Received .body without .head")
+
+                    // if this is a request from a Streaming Lambda Handler,
+                    // stream the response instead of buffering it
+                    if self.isStreamingResponse(requestHead) {
+                        await self.responsePool.push(
+                            LocalServerResponse(id: requestId, body: body)
+                        )
+                    } else {
+                        requestBody.setOrWriteImmutableBuffer(body)
+                    }
 
                 case .end:
                     precondition(requestHead != nil, "Received .end without .head")
-                    // process the request
-                    let response = try await self.processRequest(
-                        head: requestHead,
-                        body: requestBody,
-                        logger: logger
-                    )
-                    // send the responses
-                    try await self.sendResponse(
-                        response: response,
-                        outbound: outbound,
-                        logger: logger
-                    )
+                    if self.isStreamingResponse(requestHead) {
+                        // for streaming responses, send the final response part
+                        await self.responsePool.push(
+                            LocalServerResponse(id: requestId, final: true)
+                        )
+                    } else {
+                        // process the buffered request for non-streaming requests
+                        try await self.processRequestAndSendResponse(
+                            head: requestHead,
+                            body: requestBody,
+                            outbound: outbound,
+                            logger: logger
+                        )
+                    }
+
+                    // reset the request state for the next request
                    requestHead = nil
                    requestBody = nil
+                    requestId = nil
                 }
             }
         }
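Not part of the diff: a minimal sketch of the push sequence the `.head`/`.body`/`.end` cases above produce for one chunked `POST /:requestId/response`. `Response` and `Pool` are simplified stand-ins for the `LocalServerResponse` and pool types defined elsewhere in this file, and the chunk contents are made up.

```swift
import NIOCore
import NIOHTTP1

// Stand-ins so the sketch compiles on its own; the real LocalServerResponse
// and Pool types are defined later in Lambda+LocalServer.swift.
struct Response { let id: String?; let status: HTTPResponseStatus?; let body: ByteBuffer?; let final: Bool }
actor Pool { var queued: [Response] = []; func push(_ r: Response) { queued.append(r) } }

// The push sequence the connection handler performs for one chunked
// POST /:requestId/response, mirroring the .head / .body / .end cases.
func emitStreamedResponse(into pool: Pool, requestId: String) async {
    // .head arrived: a partial response carrying only the status
    await pool.push(Response(id: requestId, status: .ok, body: nil, final: false))
    // each .body chunk is forwarded as-is instead of being buffered
    for chunk in ["Hello", ", ", "world"] {
        await pool.push(Response(id: requestId, status: nil, body: ByteBuffer(string: chunk), final: false))
    }
    // .end arrived: an empty, final part lets the client connection close
    await pool.push(Response(id: requestId, status: nil, body: nil, final: true))
}
```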
requestHead.headers["Transfer-Encoding"].contains("Chunked")) + } + + /// This function parses and returns the requestId or nil if the request doesn't contain a requestId + private func getRequestId(from head: HTTPRequestHead) -> String? { + let parts = head.uri.split(separator: "/") + return parts.count > 2 ? String(parts[parts.count - 2]) : nil + } /// This function process the URI request sent by the client and by the Lambda function /// /// It enqueues the client invocation and iterate over the invocation queue when the Lambda function sends /next request - /// It answers the /:requestID/response and /:requestID/error requests sent by the Lambda function but do not process the body + /// It answers the /:requestId/response and /:requestId/error requests sent by the Lambda function but do not process the body /// /// - Parameters: /// - head: the HTTP request head /// - body: the HTTP request body /// - Throws: /// - Returns: the response to send back to the client or the Lambda function - private func processRequest( + private func processRequestAndSendResponse( head: HTTPRequestHead, body: ByteBuffer?, + outbound: NIOAsyncChannelOutboundWriter, logger: Logger - ) async throws -> LocalServerResponse { + ) async throws { + var logger = logger + logger[metadataKey: "URI"] = "\(head.method) \(head.uri)" if let body { logger.trace( "Processing request", - metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"] + metadata: ["Body": "\(String(buffer: body))"] ) } else { - logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"]) + logger.trace("Processing request") } switch (head.method, head.uri) { @@ -314,27 +358,36 @@ internal struct LambdaHTTPServer { // client POST /invoke case (.POST, let url) where url.hasSuffix(self.invocationEndpoint): guard let body else { - return .init(status: .badRequest, headers: [], body: nil) + return try await sendResponse( + .init(status: .badRequest, final: true), + outbound: outbound, + logger: logger + ) } // we always accept the /invoke request and push them to the pool let requestId = "\(DispatchTime.now().uptimeNanoseconds)" - var logger = logger - logger[metadataKey: "requestID"] = "\(requestId)" - logger.trace("/invoke received invocation") + logger[metadataKey: "requestId"] = "\(requestId)" + logger.trace("/invoke received invocation, pushing it to the pool and wait for a lambda response") await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) // wait for the lambda function to process the request for try await response in self.responsePool { - logger.trace( - "Received response to return to client", - metadata: ["requestId": "\(response.requestId ?? "")"] - ) + logger[metadataKey: "response requestId"] = "\(response.requestId ?? "nil")" + logger.trace("Received response to return to client") if response.requestId == requestId { - return response + logger.trace("/invoke requestId is valid, sending the response") + // send the response to the client + // if the response is final, we can send it and return + // if the response is not final, we can send it and wait for the next response + try await self.sendResponse(response, outbound: outbound, logger: logger) + if response.final == true { + logger.trace("/invoke returning") + return // if the response is final, we can return and close the connection + } } else { logger.error( "Received response for a different request id", - metadata: ["response requestId": "\(response.requestId ?? 
"")", "requestId": "\(requestId)"] + metadata: ["response requestId": "\(response.requestId ?? "")"] ) // should we return an error here ? Or crash as this is probably a programming error? } @@ -345,7 +398,11 @@ internal struct LambdaHTTPServer { // client uses incorrect HTTP method case (_, let url) where url.hasSuffix(self.invocationEndpoint): - return .init(status: .methodNotAllowed) + return try await sendResponse( + .init(status: .methodNotAllowed, final: true), + outbound: outbound, + logger: logger + ) // // lambda invocations @@ -358,85 +415,112 @@ internal struct LambdaHTTPServer { // pop the tasks from the queue logger.trace("/next waiting for /invoke") for try await invocation in self.invocationPool { - logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"]) - // this call also stores the invocation requestId into the response - return invocation.makeResponse(status: .accepted) + logger[metadataKey: "requestId"] = "\(invocation.requestId)" + logger.trace("/next retrieved invocation") + // tell the lambda function we accepted the invocation + return try await sendResponse(invocation.acceptedResponse(), outbound: outbound, logger: logger) } // What todo when there is no more tasks to process? // This should not happen as the async iterator blocks until there is a task to process fatalError("No more invocations to process - the async for loop should not return") - // :requestID/response endpoint is called by the lambda posting the response + // :requestId/response endpoint is called by the lambda posting the response case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): - let parts = head.uri.split(separator: "/") - guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { + guard let requestId = getRequestId(from: head) else { // the request is malformed, since we were expecting a requestId in the path - return .init(status: .badRequest) + return try await sendResponse( + .init(status: .badRequest, final: true), + outbound: outbound, + logger: logger + ) } // enqueue the lambda function response to be served as response to the client /invoke - logger.trace("/:requestID/response received response", metadata: ["requestId": "\(requestID)"]) + logger.trace("/:requestId/response received response", metadata: ["requestId": "\(requestId)"]) await self.responsePool.push( LocalServerResponse( - id: requestID, + id: requestId, status: .ok, - headers: [("Content-Type", "application/json")], - body: body + // the local server has no mecanism to collect headers set by the lambda function + headers: HTTPHeaders(), + body: body, + final: true ) ) // tell the Lambda function we accepted the response - return .init(id: requestID, status: .accepted) + return try await sendResponse( + .init(id: requestId, status: .accepted, final: true), + outbound: outbound, + logger: logger + ) - // :requestID/error endpoint is called by the lambda posting an error response - // we accept all requestID and we do not handle the body, we just acknowledge the request + // :requestId/error endpoint is called by the lambda posting an error response + // we accept all requestId and we do not handle the body, we just acknowledge the request case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix): - let parts = head.uri.split(separator: "/") - guard let requestID = parts.count > 2 ? 
@@ -358,85 +415,112 @@ internal struct LambdaHTTPServer {
 
             // pop the tasks from the queue
             logger.trace("/next waiting for /invoke")
             for try await invocation in self.invocationPool {
-                logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
-                // this call also stores the invocation requestId into the response
-                return invocation.makeResponse(status: .accepted)
+                logger[metadataKey: "requestId"] = "\(invocation.requestId)"
+                logger.trace("/next retrieved invocation")
+                // tell the lambda function we accepted the invocation
+                return try await sendResponse(invocation.acceptedResponse(), outbound: outbound, logger: logger)
             }
             // What todo when there is no more tasks to process?
             // This should not happen as the async iterator blocks until there is a task to process
             fatalError("No more invocations to process - the async for loop should not return")
 
-        // :requestID/response endpoint is called by the lambda posting the response
+        // :requestId/response endpoint is called by the lambda posting the response
         case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix):
-            let parts = head.uri.split(separator: "/")
-            guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else {
+            guard let requestId = getRequestId(from: head) else {
                 // the request is malformed, since we were expecting a requestId in the path
-                return .init(status: .badRequest)
+                return try await sendResponse(
+                    .init(status: .badRequest, final: true),
+                    outbound: outbound,
+                    logger: logger
+                )
             }
             // enqueue the lambda function response to be served as response to the client /invoke
-            logger.trace("/:requestID/response received response", metadata: ["requestId": "\(requestID)"])
+            logger.trace("/:requestId/response received response", metadata: ["requestId": "\(requestId)"])
             await self.responsePool.push(
                 LocalServerResponse(
-                    id: requestID,
+                    id: requestId,
                     status: .ok,
-                    headers: [("Content-Type", "application/json")],
-                    body: body
+                    // the local server has no mechanism to collect headers set by the lambda function
+                    headers: HTTPHeaders(),
+                    body: body,
+                    final: true
                 )
             )
             // tell the Lambda function we accepted the response
-            return .init(id: requestID, status: .accepted)
+            return try await sendResponse(
+                .init(id: requestId, status: .accepted, final: true),
+                outbound: outbound,
+                logger: logger
+            )
 
-        // :requestID/error endpoint is called by the lambda posting an error response
-        // we accept all requestID and we do not handle the body, we just acknowledge the request
+        // :requestId/error endpoint is called by the lambda posting an error response
+        // we accept all requestId and we do not handle the body, we just acknowledge the request
         case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix):
-            let parts = head.uri.split(separator: "/")
-            guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else {
+            guard let requestId = getRequestId(from: head) else {
                 // the request is malformed, since we were expecting a requestId in the path
-                return .init(status: .badRequest)
+                return try await sendResponse(
+                    .init(status: .badRequest, final: true),
+                    outbound: outbound,
+                    logger: logger
+                )
             }
             // enqueue the lambda function response to be served as response to the client /invoke
-            logger.trace("/:requestID/response received response", metadata: ["requestId": "\(requestID)"])
+            logger.trace("/:requestId/error received error response", metadata: ["requestId": "\(requestId)"])
             await self.responsePool.push(
                 LocalServerResponse(
-                    id: requestID,
+                    id: requestId,
                     status: .internalServerError,
-                    headers: [("Content-Type", "application/json")],
-                    body: body
+                    headers: HTTPHeaders([("Content-Type", "application/json")]),
+                    body: body,
+                    final: true
                 )
             )
-            return .init(status: .accepted)
+            return try await sendResponse(.init(status: .accepted, final: true), outbound: outbound, logger: logger)
 
         // unknown call
         default:
-            return .init(status: .notFound)
+            return try await sendResponse(.init(status: .notFound, final: true), outbound: outbound, logger: logger)
         }
     }
 
     private func sendResponse(
-        response: LocalServerResponse,
+        _ response: LocalServerResponse,
         outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>,
         logger: Logger
     ) async throws {
-        var headers = HTTPHeaders(response.headers ?? [])
-        headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)")
-
-        logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
-        try await outbound.write(
-            HTTPServerResponsePart.head(
-                HTTPResponseHead(
-                    version: .init(major: 1, minor: 1),
-                    status: response.status,
-                    headers: headers
+        var logger = logger
+        logger[metadataKey: "requestId"] = "\(response.requestId ?? "nil")"
+        logger.trace("Writing response for \(response.status?.code ?? 0)")
+
+        var headers = response.headers ?? HTTPHeaders()
+        if let body = response.body {
+            headers.add(name: "Content-Length", value: "\(body.readableBytes)")
+        }
+
+        if let status = response.status {
+            logger.trace("Sending status and headers")
+            try await outbound.write(
+                HTTPServerResponsePart.head(
+                    HTTPResponseHead(
+                        version: .init(major: 1, minor: 1),
+                        status: status,
+                        headers: headers
+                    )
                 )
             )
-        )
+        }
 
         if let body = response.body {
+            logger.trace("Sending body")
             try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(body)))
         }
 
-        try await outbound.write(HTTPServerResponsePart.end(nil))
+        if response.final {
+            logger.trace("Sending end")
+            try await outbound.write(HTTPServerResponsePart.end(nil))
+        }
     }
 
     /// A shared data structure to store the current invocation or response requests and the continuation objects.
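The `invocationPool` and `responsePool` used throughout are instances of the pool type referenced by the doc comment above, whose implementation sits outside these hunks. Below is a minimal stand-in with the same surface, an awaitable `push` plus `AsyncSequence` iteration, built on `AsyncStream` for brevity rather than whatever continuation bookkeeping the real type performs.

```swift
// A simplified pool: producers push, a single consumer iterates.
// Requires Swift 5.9+ for AsyncStream.makeStream(of:).
struct SimplePool<Element: Sendable>: AsyncSequence {
    private let stream: AsyncStream<Element>
    private let continuation: AsyncStream<Element>.Continuation

    init() {
        (self.stream, self.continuation) = AsyncStream.makeStream(of: Element.self)
    }

    // async to match the call sites above (await pool.push(...)),
    // even though yielding on an AsyncStream never suspends
    func push(_ element: Element) async {
        self.continuation.yield(element)
    }

    func makeAsyncIterator() -> AsyncStream<Element>.AsyncIterator {
        self.stream.makeAsyncIterator()
    }
}

// Usage mirroring the server: /invoke pushes, /next consumes.
func demo() async {
    let invocations = SimplePool<String>()
    await invocations.push("invocation-1")
    for await invocation in invocations {
        print("retrieved \(invocation)")
        break  // the real /next handler returns after one invocation
    }
}
```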
@@ -520,15 +604,22 @@ internal struct LambdaHTTPServer {
 
     private struct LocalServerResponse: Sendable {
         let requestId: String?
-        let status: HTTPResponseStatus
-        let headers: [(String, String)]?
+        let status: HTTPResponseStatus?
+        let headers: HTTPHeaders?
         let body: ByteBuffer?
-        init(id: String? = nil, status: HTTPResponseStatus, headers: [(String, String)]? = nil, body: ByteBuffer? = nil)
-        {
+        let final: Bool
+        init(
+            id: String? = nil,
+            status: HTTPResponseStatus? = nil,
+            headers: HTTPHeaders? = nil,
+            body: ByteBuffer? = nil,
+            final: Bool = false
+        ) {
             self.requestId = id
             self.status = status
             self.headers = headers
             self.body = body
+            self.final = final
         }
     }
 
@@ -536,10 +627,10 @@ internal struct LambdaHTTPServer {
         let requestId: String
         let request: ByteBuffer
 
-        func makeResponse(status: HTTPResponseStatus) -> LocalServerResponse {
+        func acceptedResponse() -> LocalServerResponse {
             // required headers
-            let headers = [
+            let headers = HTTPHeaders([
                 (AmazonHeaders.requestID, self.requestId),
                 (
                     AmazonHeaders.invokedFunctionARN,
                     "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"
                 ),
                 (AmazonHeaders.traceID, "Root=\(AmazonHeaders.generateXRayTraceID());Sampled=1"),
                 (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"),
-            ]
-
-            return LocalServerResponse(id: self.requestId, status: status, headers: headers, body: self.request)
+            ])
+
+            return LocalServerResponse(
+                id: self.requestId,
+                status: .accepted,
+                headers: headers,
+                body: self.request,
+                final: true
+            )
         }
     }
 }
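For context, what exercises the new streaming path is a handler that writes its response incrementally. The sketch below assumes the package's documented v2 streaming API (`StreamingLambdaHandler`, `LambdaResponseStreamWriter`, `LambdaRuntime`); the handler name and payload are illustrative, not taken from this diff. Run against the local server, its chunked `POST /:requestId/response` is what `isStreamingResponse` detects, and each `write` reaches the `/invoke` client as a partial response.

```swift
import AWSLambdaRuntime
import NIOCore

// Illustrative streaming handler: three chunks, one second apart.
struct SendNumbersWithPause: StreamingLambdaHandler {
    func handle(
        _ event: ByteBuffer,
        responseWriter: some LambdaResponseStreamWriter,
        context: LambdaContext
    ) async throws {
        for i in 1...3 {
            // each write becomes a chunked-transfer body part on the wire
            try await responseWriter.write(ByteBuffer(string: "\(i)\n"))
            try await Task.sleep(for: .seconds(1))
        }
        // closes the response stream; the local server then pushes the final part
        try await responseWriter.finish()
    }
}

let runtime = LambdaRuntime(handler: SendNumbersWithPause())
try await runtime.run()
```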