-
Notifications
You must be signed in to change notification settings - Fork 58
feat: add support for accessing AWS Neptune using assume role credentials #818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
73df1bb
f74cde3
6a2045c
5b96d96
04899a6
0fdb2f6
61e426d
b8509f3
dfc34dc
4cfdf4e
90fedc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,17 +14,28 @@ import { clientRoot, proxyServerRoot } from "./paths.js"; | |
import { errorHandlingMiddleware, handleError } from "./error-handler.js"; | ||
import { BooleanStringSchema, env } from "./env.js"; | ||
import { pipeline } from "stream"; | ||
import { AssumeRoleCommand, STSClient } from "@aws-sdk/client-sts"; | ||
|
||
const app = express(); | ||
|
||
const DEFAULT_SERVICE_TYPE = "neptune-db"; | ||
|
||
interface AwsCredentials { | ||
accessKeyId: string; | ||
secretAccessKey: string; | ||
sessionToken?: string; | ||
expiration?: Date; | ||
} | ||
|
||
const credentialCache: { [roleArn: string]: AwsCredentials } = {}; | ||
|
||
interface DbQueryIncomingHttpHeaders extends IncomingHttpHeaders { | ||
queryid?: string; | ||
"graph-db-connection-url"?: string; | ||
"aws-neptune-region"?: string; | ||
"service-type"?: string; | ||
"db-query-logging-enabled"?: string; | ||
"aws-assume-role-arn"?: string; | ||
} | ||
|
||
interface LoggerIncomingHttpHeaders extends IncomingHttpHeaders { | ||
|
@@ -34,23 +45,97 @@ interface LoggerIncomingHttpHeaders extends IncomingHttpHeaders { | |
|
||
app.use(requestLoggingMiddleware()); | ||
|
||
// Function to get IAM headers for AWS4 signing process. | ||
async function getIAMHeaders(options: string | aws4.Request) { | ||
// Function to check if the credentials are valid. | ||
function areCredentialsValid(creds: AwsCredentials): boolean { | ||
return creds.expiration | ||
? new Date(creds.expiration).getTime() - Date.now() > 5 * 60 * 1000 | ||
: true; | ||
} | ||
|
||
async function getCredentials( | ||
awsAssumeRoleArn: string | undefined, | ||
region: string | undefined | ||
) { | ||
if (awsAssumeRoleArn !== undefined && awsAssumeRoleArn !== "") { | ||
if ( | ||
credentialCache[awsAssumeRoleArn] && | ||
areCredentialsValid(credentialCache[awsAssumeRoleArn]) | ||
) { | ||
return credentialCache[awsAssumeRoleArn]; | ||
} | ||
try { | ||
const command = new AssumeRoleCommand({ | ||
RoleArn: awsAssumeRoleArn, | ||
RoleSessionName: "GraphExplorerProxyServer", | ||
}); | ||
const stsClient = new STSClient({ region: region }); | ||
const { Credentials } = await stsClient.send(command); | ||
|
||
if ( | ||
!Credentials || | ||
!Credentials.AccessKeyId || | ||
!Credentials.SecretAccessKey || | ||
!Credentials.SessionToken | ||
) { | ||
throw new Error("Failed to assume role, no credentials returned"); | ||
} | ||
|
||
proxyLogger.debug( | ||
"Assumed role successfully using the provided role ARN %s, it will expire at: %s", | ||
awsAssumeRoleArn, | ||
Credentials.Expiration | ||
); | ||
credentialCache[awsAssumeRoleArn] = { | ||
accessKeyId: Credentials.AccessKeyId, | ||
secretAccessKey: Credentials.SecretAccessKey, | ||
sessionToken: Credentials.SessionToken, | ||
expiration: Credentials.Expiration, | ||
}; | ||
|
||
return credentialCache[awsAssumeRoleArn]; | ||
} catch (error) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible this It might be worth combining the Then this function becomes something more like (pseudo code): function getIAMHeaders() {
const creds = getCredentials();
if (!creds) {
throw new Error();
}
return aws4.sign(options, {
accessKeyId: creds.AccessKeyId,
secretAccessKey: creds.SecretAccessKey,
...(creds.SessionToken && {
sessionToken: creds.SessionToken,
}),
});
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense, I can make that change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed in - dfc34dc There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice. I'll take another look at this soon. |
||
proxyLogger.error( | ||
"IAM is enabled but credentials cannot be assumed using the provided role ARN: %s, Error: %s", | ||
awsAssumeRoleArn, | ||
error | ||
); | ||
return undefined; | ||
} | ||
} | ||
|
||
const credentialProvider = fromNodeProviderChain(); | ||
const creds = await credentialProvider(); | ||
if (creds === undefined) { | ||
throw new Error( | ||
proxyLogger.error( | ||
"IAM is enabled but credentials cannot be found on the credential provider chain." | ||
); | ||
return undefined; | ||
} | ||
|
||
const headers = aws4.sign(options, { | ||
return { | ||
accessKeyId: creds.accessKeyId, | ||
secretAccessKey: creds.secretAccessKey, | ||
sessionToken: creds.sessionToken, | ||
}; | ||
} | ||
|
||
// Function to get IAM headers for AWS4 signing process. | ||
async function getIAMHeaders( | ||
options: string | aws4.Request, | ||
region: string | undefined, | ||
awsAssumeRoleArn: string | undefined | ||
) { | ||
const creds = await getCredentials(awsAssumeRoleArn, region); | ||
if (!creds) { | ||
throw new Error( | ||
"IAM is enabled but credentials cannot be found or assumed." | ||
); | ||
} | ||
return aws4.sign(options, { | ||
accessKeyId: creds.accessKeyId, | ||
secretAccessKey: creds.secretAccessKey, | ||
...(creds.sessionToken && { sessionToken: creds.sessionToken }), | ||
}); | ||
|
||
return headers; | ||
} | ||
|
||
// Function to retry fetch requests with exponential backoff. | ||
|
@@ -60,20 +145,25 @@ const retryFetch = async ( | |
isIamEnabled: boolean, | ||
region: string | undefined, | ||
serviceType: string, | ||
awsAssumeRoleArn: string | undefined, | ||
retryDelay = 10000, | ||
refetchMaxRetries = 1 | ||
) => { | ||
for (let i = 0; i < refetchMaxRetries; i++) { | ||
if (isIamEnabled) { | ||
const data = await getIAMHeaders({ | ||
host: url.hostname, | ||
port: url.port, | ||
path: url.pathname + url.search, | ||
service: serviceType, | ||
const data = await getIAMHeaders( | ||
{ | ||
host: url.hostname, | ||
port: url.port, | ||
path: url.pathname + url.search, | ||
service: serviceType, | ||
region, | ||
method: options.method, | ||
body: options.body ?? undefined, | ||
}, | ||
region, | ||
method: options.method, | ||
body: options.body ?? undefined, | ||
}); | ||
awsAssumeRoleArn | ||
); | ||
|
||
options = { | ||
host: url.hostname, | ||
|
@@ -130,15 +220,17 @@ async function fetchData( | |
options: RequestInit, | ||
isIamEnabled: boolean, | ||
region: string | undefined, | ||
serviceType: string | ||
serviceType: string, | ||
awsAssumeRoleArn: string | undefined | ||
) { | ||
try { | ||
const response = await retryFetch( | ||
new URL(url), | ||
options, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
|
||
// Set the headers from the fetch response to the client response | ||
|
@@ -201,6 +293,7 @@ app.post("/sparql", (req, res, next) => { | |
const serviceType = isIamEnabled | ||
? (headers["service-type"] ?? DEFAULT_SERVICE_TYPE) | ||
: ""; | ||
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : ""; | ||
|
||
/// Function to cancel long running queries if the client disappears before completion | ||
async function cancelQuery() { | ||
|
@@ -221,7 +314,8 @@ app.post("/sparql", (req, res, next) => { | |
}, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
} catch (err) { | ||
// Not really an error | ||
|
@@ -275,7 +369,8 @@ app.post("/sparql", (req, res, next) => { | |
requestOptions, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
}); | ||
|
||
|
@@ -293,6 +388,7 @@ app.post("/gremlin", (req, res, next) => { | |
const serviceType = isIamEnabled | ||
? (headers["service-type"] ?? DEFAULT_SERVICE_TYPE) | ||
: ""; | ||
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : ""; | ||
|
||
// Validate the input before making any external calls. | ||
const queryString = req.body.query; | ||
|
@@ -320,7 +416,8 @@ app.post("/gremlin", (req, res, next) => { | |
{ method: "GET" }, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
} catch (err) { | ||
// Not really an error | ||
|
@@ -360,7 +457,8 @@ app.post("/gremlin", (req, res, next) => { | |
requestOptions, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
}); | ||
|
||
|
@@ -398,6 +496,7 @@ app.post("/openCypher", (req, res, next) => { | |
const serviceType = isIamEnabled | ||
? (headers["service-type"] ?? DEFAULT_SERVICE_TYPE) | ||
: ""; | ||
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : ""; | ||
|
||
return fetchData( | ||
res, | ||
|
@@ -406,7 +505,8 @@ app.post("/openCypher", (req, res, next) => { | |
requestOptions, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
}); | ||
|
||
|
@@ -424,6 +524,7 @@ app.get("/summary", (req, res, next) => { | |
}; | ||
|
||
const region = isIamEnabled ? headers["aws-neptune-region"] : ""; | ||
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : ""; | ||
|
||
fetchData( | ||
res, | ||
|
@@ -432,7 +533,8 @@ app.get("/summary", (req, res, next) => { | |
requestOptions, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
}); | ||
|
||
|
@@ -450,6 +552,7 @@ app.get("/pg/statistics/summary", (req, res, next) => { | |
}; | ||
|
||
const region = isIamEnabled ? headers["aws-neptune-region"] : ""; | ||
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : ""; | ||
|
||
fetchData( | ||
res, | ||
|
@@ -458,7 +561,8 @@ app.get("/pg/statistics/summary", (req, res, next) => { | |
requestOptions, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
}); | ||
|
||
|
@@ -476,6 +580,7 @@ app.get("/rdf/statistics/summary", (req, res, next) => { | |
}; | ||
|
||
const region = isIamEnabled ? headers["aws-neptune-region"] : ""; | ||
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : ""; | ||
|
||
fetchData( | ||
res, | ||
|
@@ -484,7 +589,8 @@ app.get("/rdf/statistics/summary", (req, res, next) => { | |
requestOptions, | ||
isIamEnabled, | ||
region, | ||
serviceType | ||
serviceType, | ||
awsAssumeRoleArn | ||
); | ||
}); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this time range something that is consistent for all users and situations? Perhaps there is a better way to check for expiration instead of checking against a hard coded timespan?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually credentials do come with the expiration time and we would need to check against that time if the credentials are valid or not. One other way I could think of is it let it fail and then retry with the new credentials. But I think that is not very efficient approach. Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good question. My immediate thought would be
It surely complicates the logic, but I think it will yield the best outcome.