Skip to content

feat: add support for accessing AWS Neptune using assume role credentials #818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Upcoming

- **Improved** accessing neptune using assume role
([#818](https://github.com/aws/graph-explorer/pull/818))
- **Added** ability to restore the graph from the previous session
([#826](https://github.com/aws/graph-explorer/pull/826))
- **Updated** styling of the connections screen
Expand Down
1 change: 1 addition & 0 deletions packages/graph-explorer-proxy-server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"license": "Apache-2.0",
"dependencies": {
"@aws-sdk/credential-providers": "^3.758.0",
"@aws-sdk/client-sts": "^3.758.0",
"@graph-explorer/shared": "workspace:*",
"aws4": "^1.13.2",
"body-parser": "^1.20.3",
Expand Down
154 changes: 130 additions & 24 deletions packages/graph-explorer-proxy-server/src/node-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,28 @@ import { clientRoot, proxyServerRoot } from "./paths.js";
import { errorHandlingMiddleware, handleError } from "./error-handler.js";
import { BooleanStringSchema, env } from "./env.js";
import { pipeline } from "stream";
import { AssumeRoleCommand, STSClient } from "@aws-sdk/client-sts";

const app = express();

const DEFAULT_SERVICE_TYPE = "neptune-db";

interface AwsCredentials {
accessKeyId: string;
secretAccessKey: string;
sessionToken?: string;
expiration?: Date;
}

const credentialCache: { [roleArn: string]: AwsCredentials } = {};

interface DbQueryIncomingHttpHeaders extends IncomingHttpHeaders {
queryid?: string;
"graph-db-connection-url"?: string;
"aws-neptune-region"?: string;
"service-type"?: string;
"db-query-logging-enabled"?: string;
"aws-assume-role-arn"?: string;
}

interface LoggerIncomingHttpHeaders extends IncomingHttpHeaders {
Expand All @@ -34,23 +45,97 @@ interface LoggerIncomingHttpHeaders extends IncomingHttpHeaders {

app.use(requestLoggingMiddleware());

// Function to get IAM headers for AWS4 signing process.
async function getIAMHeaders(options: string | aws4.Request) {
// Function to check if the credentials are valid.
function areCredentialsValid(creds: AwsCredentials): boolean {
return creds.expiration
? new Date(creds.expiration).getTime() - Date.now() > 5 * 60 * 1000
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this time range something that is consistent for all users and situations? Perhaps there is a better way to check for expiration instead of checking against a hard coded timespan?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually credentials do come with the expiration time and we would need to check against that time if the credentials are valid or not. One other way I could think of is it let it fail and then retry with the new credentials. But I think that is not very efficient approach. Thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good question. My immediate thought would be

why-not-both

It surely complicates the logic, but I think it will yield the best outcome.

: true;
}

async function getCredentials(
awsAssumeRoleArn: string | undefined,
region: string | undefined
) {
if (awsAssumeRoleArn !== undefined && awsAssumeRoleArn !== "") {
if (
credentialCache[awsAssumeRoleArn] &&
areCredentialsValid(credentialCache[awsAssumeRoleArn])
) {
return credentialCache[awsAssumeRoleArn];
}
try {
const command = new AssumeRoleCommand({
RoleArn: awsAssumeRoleArn,
RoleSessionName: "GraphExplorerProxyServer",
});
const stsClient = new STSClient({ region: region });
const { Credentials } = await stsClient.send(command);

if (
!Credentials ||
!Credentials.AccessKeyId ||
!Credentials.SecretAccessKey ||
!Credentials.SessionToken
) {
throw new Error("Failed to assume role, no credentials returned");
}

proxyLogger.debug(
"Assumed role successfully using the provided role ARN %s, it will expire at: %s",
awsAssumeRoleArn,
Credentials.Expiration
);
credentialCache[awsAssumeRoleArn] = {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
};

return credentialCache[awsAssumeRoleArn];
} catch (error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible this catch will apply to more than the STSClient. For example, I think it is possible that the aws4.sign() function could throw. If that happens the error message below does not apply.

It might be worth combining the fromNodeProviderChain() credential path and the new STSClient credential path in to a separate function. That allows you to put the try/catch in there.

Then this function becomes something more like (pseudo code):

function getIAMHeaders() {
  const creds = getCredentials();
  if (!creds) {
    throw new Error();
  }
  return aws4.sign(options, {
    accessKeyId: creds.AccessKeyId,
    secretAccessKey: creds.SecretAccessKey,
    ...(creds.SessionToken && {
      sessionToken: creds.SessionToken,
    }),
  });
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, I can make that change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in - dfc34dc

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I'll take another look at this soon.

proxyLogger.error(
"IAM is enabled but credentials cannot be assumed using the provided role ARN: %s, Error: %s",
awsAssumeRoleArn,
error
);
return undefined;
}
}

const credentialProvider = fromNodeProviderChain();
const creds = await credentialProvider();
if (creds === undefined) {
throw new Error(
proxyLogger.error(
"IAM is enabled but credentials cannot be found on the credential provider chain."
);
return undefined;
}

const headers = aws4.sign(options, {
return {
accessKeyId: creds.accessKeyId,
secretAccessKey: creds.secretAccessKey,
sessionToken: creds.sessionToken,
};
}

// Function to get IAM headers for AWS4 signing process.
async function getIAMHeaders(
options: string | aws4.Request,
region: string | undefined,
awsAssumeRoleArn: string | undefined
) {
const creds = await getCredentials(awsAssumeRoleArn, region);
if (!creds) {
throw new Error(
"IAM is enabled but credentials cannot be found or assumed."
);
}
return aws4.sign(options, {
accessKeyId: creds.accessKeyId,
secretAccessKey: creds.secretAccessKey,
...(creds.sessionToken && { sessionToken: creds.sessionToken }),
});

return headers;
}

// Function to retry fetch requests with exponential backoff.
Expand All @@ -60,20 +145,25 @@ const retryFetch = async (
isIamEnabled: boolean,
region: string | undefined,
serviceType: string,
awsAssumeRoleArn: string | undefined,
retryDelay = 10000,
refetchMaxRetries = 1
) => {
for (let i = 0; i < refetchMaxRetries; i++) {
if (isIamEnabled) {
const data = await getIAMHeaders({
host: url.hostname,
port: url.port,
path: url.pathname + url.search,
service: serviceType,
const data = await getIAMHeaders(
{
host: url.hostname,
port: url.port,
path: url.pathname + url.search,
service: serviceType,
region,
method: options.method,
body: options.body ?? undefined,
},
region,
method: options.method,
body: options.body ?? undefined,
});
awsAssumeRoleArn
);

options = {
host: url.hostname,
Expand Down Expand Up @@ -130,15 +220,17 @@ async function fetchData(
options: RequestInit,
isIamEnabled: boolean,
region: string | undefined,
serviceType: string
serviceType: string,
awsAssumeRoleArn: string | undefined
) {
try {
const response = await retryFetch(
new URL(url),
options,
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);

// Set the headers from the fetch response to the client response
Expand Down Expand Up @@ -201,6 +293,7 @@ app.post("/sparql", (req, res, next) => {
const serviceType = isIamEnabled
? (headers["service-type"] ?? DEFAULT_SERVICE_TYPE)
: "";
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : "";

/// Function to cancel long running queries if the client disappears before completion
async function cancelQuery() {
Expand All @@ -221,7 +314,8 @@ app.post("/sparql", (req, res, next) => {
},
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
} catch (err) {
// Not really an error
Expand Down Expand Up @@ -275,7 +369,8 @@ app.post("/sparql", (req, res, next) => {
requestOptions,
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
});

Expand All @@ -293,6 +388,7 @@ app.post("/gremlin", (req, res, next) => {
const serviceType = isIamEnabled
? (headers["service-type"] ?? DEFAULT_SERVICE_TYPE)
: "";
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : "";

// Validate the input before making any external calls.
const queryString = req.body.query;
Expand Down Expand Up @@ -320,7 +416,8 @@ app.post("/gremlin", (req, res, next) => {
{ method: "GET" },
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
} catch (err) {
// Not really an error
Expand Down Expand Up @@ -360,7 +457,8 @@ app.post("/gremlin", (req, res, next) => {
requestOptions,
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
});

Expand Down Expand Up @@ -398,6 +496,7 @@ app.post("/openCypher", (req, res, next) => {
const serviceType = isIamEnabled
? (headers["service-type"] ?? DEFAULT_SERVICE_TYPE)
: "";
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : "";

return fetchData(
res,
Expand All @@ -406,7 +505,8 @@ app.post("/openCypher", (req, res, next) => {
requestOptions,
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
});

Expand All @@ -424,6 +524,7 @@ app.get("/summary", (req, res, next) => {
};

const region = isIamEnabled ? headers["aws-neptune-region"] : "";
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : "";

fetchData(
res,
Expand All @@ -432,7 +533,8 @@ app.get("/summary", (req, res, next) => {
requestOptions,
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
});

Expand All @@ -450,6 +552,7 @@ app.get("/pg/statistics/summary", (req, res, next) => {
};

const region = isIamEnabled ? headers["aws-neptune-region"] : "";
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : "";

fetchData(
res,
Expand All @@ -458,7 +561,8 @@ app.get("/pg/statistics/summary", (req, res, next) => {
requestOptions,
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
});

Expand All @@ -476,6 +580,7 @@ app.get("/rdf/statistics/summary", (req, res, next) => {
};

const region = isIamEnabled ? headers["aws-neptune-region"] : "";
const awsAssumeRoleArn = isIamEnabled ? headers["aws-assume-role-arn"] : "";

fetchData(
res,
Expand All @@ -484,7 +589,8 @@ app.get("/rdf/statistics/summary", (req, res, next) => {
requestOptions,
isIamEnabled,
region,
serviceType
serviceType,
awsAssumeRoleArn
);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ function getAuthHeaders(
if (connection?.awsAuthEnabled) {
headers["aws-neptune-region"] = connection.awsRegion || "";
headers["service-type"] = connection.serviceType || DEFAULT_SERVICE_TYPE;
headers["aws-assume-role-arn"] = connection.awsAssumeRoleArn || "";
}

return { ...headers, ...typeHeaders };
Expand Down
1 change: 1 addition & 0 deletions packages/graph-explorer/src/core/connector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const activeConnectionSelector = equalSelector({
"graphDbUrl",
"awsAuthEnabled",
"awsRegion",
"awsAssumeRoleArn",
"fetchTimeoutMs",
"nodeExpansionLimit",
] as (keyof ConnectionConfig)[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type ConnectionForm = {
awsAuthEnabled?: boolean;
serviceType?: NeptuneServiceType;
awsRegion?: string;
awsAssumeRoleArn?: string;
fetchTimeoutEnabled: boolean;
fetchTimeoutMs?: number;
nodeExpansionLimitEnabled: boolean;
Expand Down Expand Up @@ -72,6 +73,7 @@ function mapToConnection(data: Required<ConnectionForm>): ConnectionConfig {
awsAuthEnabled: data.awsAuthEnabled,
serviceType: data.serviceType,
awsRegion: data.awsRegion,
awsAssumeRoleArn: data.awsAssumeRoleArn,
fetchTimeoutMs: data.fetchTimeoutEnabled ? data.fetchTimeoutMs : undefined,
nodeExpansionLimit: data.nodeExpansionLimitEnabled
? data.nodeExpansionLimit
Expand Down Expand Up @@ -164,6 +166,7 @@ const CreateConnection = ({
awsAuthEnabled: initialData?.awsAuthEnabled || false,
serviceType: initialData?.serviceType || "neptune-db",
awsRegion: initialData?.awsRegion || "",
awsAssumeRoleArn: initialData?.awsAssumeRoleArn || "",
fetchTimeoutEnabled: initialData?.fetchTimeoutEnabled || false,
fetchTimeoutMs: initialData?.fetchTimeoutMs,
nodeExpansionLimitEnabled: initialData?.nodeExpansionLimitEnabled || false,
Expand Down Expand Up @@ -335,6 +338,22 @@ const CreateConnection = ({
onValueChange={onFormChange("serviceType")}
/>
</FormItem>
<FormItem>
<Label>
AWS Assume Role ARN
<InfoTooltip>
ARN of the role that the proxy-server should assume to sign
requests. This is only required if the connector is running
outside of the AWS account that hosts the Neptune resources.
</InfoTooltip>
</Label>
<InputField
data-autofocus={true}
value={form.awsAssumeRoleArn}
onChange={onFormChange("awsAssumeRoleArn")}
placeholder="arn:aws:iam::aws-account-no:role/role-name"
/>
</FormItem>
</>
)}
<FormItem>
Expand Down
Loading