diff --git a/src/cli.ts b/src/cli.ts index 079c143b..de61e462 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -78,6 +78,10 @@ async function main(argv: string[]) { "--https-ca ", "Enable HTTPS server where the certificate was generated by this CA. Useful if you are using a self-signed certificate. Also requires --https-key and --https-cert." ) + .option( + "--unframe-grpc-web-json-requests-hostname [hostname...]", + "Rewrite received requests whose content-type is application/grpc-web+json to be application/json, mutating the body of the request accordingly. This is useful if you want plain text tape records rather than binary data. The gRPC server needs to support receiving unframed requests for this option to be useful." + ) .option( "--rewrite-before-diff [s/find/replace/g...]", "Provide regex-based rewrite rules over strings before passing them to the diffing algorithm. The regex rules use sed-style syntax. s/find/replace/ with an optional regex modifier suffixes. Capture groups can be used using sed-style \\N syntax. This is only used during replaying existing tapes.", @@ -96,6 +100,8 @@ async function main(argv: string[]) { const httpsCA: string = program.httpsCa || ""; const httpsKey: string = program.httpsKey; const httpsCert: string = program.httpsCert; + const unframeGrpcWebJsonRequestsHostnames: string[] = + program.unframeGrpcWebJsonRequestsHostname; const rewriteBeforeDiffRules: RewriteRules = program.rewriteBeforeDiff; switch (initialMode) { @@ -147,6 +153,7 @@ async function main(argv: string[]) { httpsCA, httpsKey, httpsCert, + unframeGrpcWebJsonRequestsHostnames, rewriteBeforeDiffRules, }); await server.start(port); diff --git a/src/server.ts b/src/server.ts index 2acc4c11..2295042e 100644 --- a/src/server.ts +++ b/src/server.ts @@ -28,6 +28,7 @@ export class RecordReplayServer { private defaultTape: string; private replayedTapes: Set = new Set(); private preventConditionalRequests?: boolean; + private unframeGrpcWebJsonRequestsHostnames: string[]; private rewriteBeforeDiffRules: RewriteRules; constructor(options: { @@ -42,6 +43,7 @@ export class RecordReplayServer { httpsCA?: string; httpsKey?: string; httpsCert?: string; + unframeGrpcWebJsonRequestsHostnames?: string[]; rewriteBeforeDiffRules?: RewriteRules; }) { this.currentTapeRecords = []; @@ -53,6 +55,8 @@ export class RecordReplayServer { this.persistence = new Persistence(options.tapeDir, redactHeaders); this.defaultTape = options.defaultTapeName; this.preventConditionalRequests = options.preventConditionalRequests; + this.unframeGrpcWebJsonRequestsHostnames = + options.unframeGrpcWebJsonRequestsHostnames || []; this.rewriteBeforeDiffRules = options.rewriteBeforeDiffRules || new RewriteRules(); this.loadTape(this.defaultTape); @@ -74,15 +78,6 @@ export class RecordReplayServer { return; } - if ( - this.preventConditionalRequests && - (req.method === "GET" || req.method === "HEAD") - ) { - // Headers are always coming in as lowercase. - delete req.headers["if-modified-since"]; - delete req.headers["if-none-match"]; - } - try { const request: Request = { method: req.method, @@ -90,6 +85,8 @@ export class RecordReplayServer { headers: req.headers, body: await receiveRequestBody(req), }; + + // Is this a proxay API call? if ( request.path === "/__proxay" || request.path.startsWith("/__proxay/") @@ -97,6 +94,11 @@ export class RecordReplayServer { this.handleProxayApi(request, res); return; } + + // Potentially rewrite the request before processing it at all. + this.rewriteRequest(request); + + // Process the request. const record = await this.fetchResponse(request); if (record) { this.sendResponse(record, res); @@ -240,6 +242,75 @@ export class RecordReplayServer { res.end(`Unhandled proxay request.\n\n${JSON.stringify(request)}`); } + /** + * Potentially rewrite the request before processing it. + */ + private rewriteRequest(request: Request) { + // Grab the `host` header of the request. + const hostname = (request.headers.host || null) as string | null; + + // Potentially prevent 304 responses from being able to be generated. + if ( + this.preventConditionalRequests && + (request.method === "GET" || request.method === "HEAD") + ) { + // Headers are always coming in as lowercase. + delete request.headers["if-modified-since"]; + delete request.headers["if-none-match"]; + } + + // Potentially unframe a grpc-web+json request. + if ( + request.method === "POST" && + request.headers["content-type"] === "application/grpc-web+json" && + hostname != null && + this.unframeGrpcWebJsonRequestsHostnames.includes(hostname) + ) { + this.rewriteGrpcWebJsonRequest(request); + } + } + + /** + * Rewrite a gRPC-web+json request to be unframed. + */ + private rewriteGrpcWebJsonRequest(request: Request) { + /** + * From the gRPC specification (https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md) + * + * The repeated sequence of Length-Prefixed-Message items is delivered in DATA frames: + * Length-Prefixed-Message → Compressed-Flag Message-Length Message + * Compressed-Flag → 0 / 1 # encoded as 1 byte unsigned integer + * Message-Length → {length of Message} # encoded as 4 byte unsigned integer (big endian) + * Message → *{binary octet} + */ + const compressionFlag = request.body.readUInt8(0); + const messageLength = request.body.readUInt32BE(1); + if (compressionFlag !== 0) { + console.error( + `The gRPC-web compression flag was set to 1. Do not know how to handle compressed request paylods. Aborting gRPC-web+json rewrite.` + ); + return; + } + + // Sanity check the content length. + const rawContentLength = request.headers["content-length"] as + | string + | undefined; + if (rawContentLength !== undefined) { + const contentLength = parseInt(rawContentLength, 10); + if (contentLength !== messageLength + 5) { + console.log( + `Unexpected content length. Header says "${rawContentLength}". gRPC payload length preamble says "${messageLength}".` + ); + } + } + + // Rewrite the request to be unframed. + request.headers["content-length"] = messageLength.toString(); + request.headers["content-type"] = "application/json"; + request.body = request.body.subarray(5); + } + private async fetchResponse(request: Request): Promise { switch (this.mode) { case "replay": diff --git a/src/tests/passthrough.spec.ts b/src/tests/passthrough.spec.ts index 262167a4..a7b0a512 100644 --- a/src/tests/passthrough.spec.ts +++ b/src/tests/passthrough.spec.ts @@ -1,9 +1,10 @@ import axios from "axios"; -import { PROXAY_HOST } from "./config"; +import { PROXAY_HOST, TEST_SERVER_PORT } from "./config"; import { setupServers } from "./setup"; import { BINARY_PATH, BINARY_RESPONSE, + GRPC_WEB_JSON_PATH, SIMPLE_TEXT_PATH, SIMPLE_TEXT_RESPONSE, UTF8_PATH, @@ -36,3 +37,85 @@ describe("Passthrough", () => { }); }); }); + +describe("Passthrough with grpc-web+json unframing", () => { + setupServers({ mode: "passthrough", enableUnframeGrpcWebJson: true }); + + test("response: simple text", async () => { + const response = await axios.get(`${PROXAY_HOST}${SIMPLE_TEXT_PATH}`); + expect(response.data).toBe(SIMPLE_TEXT_RESPONSE); + }); + + test("response: utf-8", async () => { + const response = await axios.get(`${PROXAY_HOST}${UTF8_PATH}`); + expect(response.data).toBe(UTF8_RESPONSE); + }); + + test("response: binary", async () => { + const response = await axios.get(`${PROXAY_HOST}${BINARY_PATH}`, { + responseType: "arraybuffer", + }); + expect(response.data).toEqual(BINARY_RESPONSE); + }); + + test("can pick any tape", async () => { + await axios.post(`${PROXAY_HOST}/__proxay/tape`, { + tape: "new-tape", + }); + }); + + test("unframes a grpc-web+json request", async () => { + const requestBody = Buffer.from([ + 0, + 0, + 0, + 0, + 31, + 123, + 34, + 101, + 109, + 97, + 105, + 108, + 34, + 58, + 34, + 102, + 111, + 111, + 46, + 98, + 97, + 114, + 64, + 101, + 120, + 97, + 109, + 112, + 108, + 101, + 46, + 99, + 111, + 109, + 34, + 125, + ]); + const response = await axios.post( + `${PROXAY_HOST}${GRPC_WEB_JSON_PATH}`, + requestBody, + { + headers: { + "content-type": "application/grpc-web+json", + host: `localhost:${TEST_SERVER_PORT}`, + }, + } + ); + expect(response.headers["content-type"]).toBe( + "application/json; charset=utf-8" + ); + expect(response.data).toEqual({ email: "foo.bar@example.com" }); + }); +}); diff --git a/src/tests/setup.ts b/src/tests/setup.ts index b201ba63..de51afcf 100644 --- a/src/tests/setup.ts +++ b/src/tests/setup.ts @@ -9,10 +9,12 @@ export function setupServers({ mode, tapeDirName = mode, defaultTapeName = "default", + enableUnframeGrpcWebJson = false, }: { mode: Mode; tapeDirName?: string; defaultTapeName?: string; + enableUnframeGrpcWebJson?: boolean; }) { const tapeDir = path.join(__dirname, "tapes", tapeDirName); const servers = { tapeDir } as { @@ -35,6 +37,9 @@ export function setupServers({ defaultTapeName, host: TEST_SERVER_HOST, timeout: 100, + unframeGrpcWebJsonRequestsHostnames: enableUnframeGrpcWebJson + ? [`localhost:${TEST_SERVER_PORT}`] + : [], }); await Promise.all([ servers.proxy.start(PROXAY_PORT), diff --git a/src/tests/testserver.ts b/src/tests/testserver.ts index aa2e52d3..cafeb801 100644 --- a/src/tests/testserver.ts +++ b/src/tests/testserver.ts @@ -22,6 +22,8 @@ export const BINARY_RESPONSE = Buffer.from([ 179, ]); +export const GRPC_WEB_JSON_PATH = "/grpc-web-json"; + /** * A test server used as a fake backend. */ @@ -33,6 +35,7 @@ export class TestServer { constructor() { this.app = express(); this.app.use(express.json()); + this.app.use(express.raw({ type: () => true })); this.app.use((_req, _res, next) => { this.requestCount += 1; next(); @@ -46,6 +49,10 @@ export class TestServer { this.app.get(BINARY_PATH, (_req, res) => { res.send(BINARY_RESPONSE); }); + this.app.post(GRPC_WEB_JSON_PATH, (req, res) => { + res.setHeader("content-type", req.headers["content-type"] as string); + res.send(req.body); + }); this.app.get(JSON_IDENTITY_PATH, (req, res) => { res.json({ data: req.path, requestCount: this.requestCount }); });