Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(NODE-6141): allow custom aws sdk config #4373

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
import * as net from 'net';

import { deserialize, type Document, serialize } from '../bson';
import { type AWSCredentialProvider } from '../cmap/auth/aws_temporary_credentials';
import { type CommandOptions, type ProxyOptions } from '../cmap/connection';
import { kDecorateResult } from '../constants';
import { getMongoDBClientEncryption } from '../deps';
Expand Down Expand Up @@ -104,6 +105,8 @@ export interface AutoEncryptionOptions {
proxyOptions?: ProxyOptions;
/** The TLS options to use connecting to the KMS provider */
tlsOptions?: CSFLEKMSTlsOptions;
/** Optional custom credential provider to use for KMS requests. */
awsCredentialProvider?: AWSCredentialProvider;
}

/**
Expand Down Expand Up @@ -153,6 +156,7 @@ export class AutoEncrypter {
_kmsProviders: KMSProviders;
_bypassMongocryptdAndCryptShared: boolean;
_contextCounter: number;
_awsCredentialProvider?: AWSCredentialProvider;

_mongocryptdManager?: MongocryptdManager;
_mongocryptdClient?: MongoClient;
Expand Down Expand Up @@ -237,6 +241,7 @@ export class AutoEncrypter {
this._proxyOptions = options.proxyOptions || {};
this._tlsOptions = options.tlsOptions || {};
this._kmsProviders = options.kmsProviders || {};
this._awsCredentialProvider = options.awsCredentialProvider;

const mongoCryptOptions: MongoCryptOptions = {
cryptoCallbacks
Expand Down Expand Up @@ -438,7 +443,7 @@ export class AutoEncrypter {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return await refreshKMSCredentials(this._kmsProviders);
return await refreshKMSCredentials(this._kmsProviders, this._awsCredentialProvider);
}

/**
Expand Down
8 changes: 7 additions & 1 deletion src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
type UUID
} from '../bson';
import { type AnyBulkWriteOperation, type BulkWriteResult } from '../bulk/common';
import { type AWSCredentialProvider } from '../cmap/auth/aws_temporary_credentials';
import { type ProxyOptions } from '../cmap/connection';
import { type Collection } from '../collection';
import { type FindCursor } from '../cursor/find_cursor';
Expand Down Expand Up @@ -81,6 +82,9 @@ export class ClientEncryption {
/** @internal */
_mongoCrypt: MongoCrypt;

/** @internal */
_awsCredentialProvider?: AWSCredentialProvider;

/** @internal */
static getMongoCrypt(): MongoCryptConstructor {
const encryption = getMongoDBClientEncryption();
Expand Down Expand Up @@ -125,6 +129,8 @@ export class ClientEncryption {
this._kmsProviders = options.kmsProviders || {};
const { timeoutMS } = resolveTimeoutOptions(client, options);
this._timeoutMS = timeoutMS;
this._awsCredentialProvider =
client.options.credentials?.mechanismProperties.AWS_CREDENTIAL_PROVIDER;

if (options.keyVaultNamespace == null) {
throw new MongoCryptInvalidArgumentError('Missing required option `keyVaultNamespace`');
Expand Down Expand Up @@ -712,7 +718,7 @@ export class ClientEncryption {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return await refreshKMSCredentials(this._kmsProviders);
return await refreshKMSCredentials(this._kmsProviders, this._awsCredentialProvider);
}

static get libmongocryptVersion() {
Expand Down
12 changes: 9 additions & 3 deletions src/client-side-encryption/providers/aws.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import {
type AWSCredentialProvider,
AWSSDKCredentialProvider
} from '../../cmap/auth/aws_temporary_credentials';
import { type KMSProviders } from '.';

/**
* @internal
*/
export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
const credentialProvider = new AWSSDKCredentialProvider();
export async function loadAWSCredentials(
kmsProviders: KMSProviders,
provider?: AWSCredentialProvider
): Promise<KMSProviders> {
const credentialProvider = new AWSSDKCredentialProvider(provider);

// We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey`
// or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings
Expand Down
8 changes: 6 additions & 2 deletions src/client-side-encryption/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { Binary } from '../../bson';
import { type AWSCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import { loadAWSCredentials } from './aws';
import { loadAzureCredentials } from './azure';
import { loadGCPCredentials } from './gcp';
Expand Down Expand Up @@ -176,11 +177,14 @@ export function isEmptyCredentials(
*
* @internal
*/
export async function refreshKMSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
export async function refreshKMSCredentials(
kmsProviders: KMSProviders,
awsProvider?: AWSCredentialProvider
): Promise<KMSProviders> {
let finalKMSProviders = kmsProviders;

if (isEmptyCredentials('aws', kmsProviders)) {
finalKMSProviders = await loadAWSCredentials(finalKMSProviders);
finalKMSProviders = await loadAWSCredentials(finalKMSProviders, awsProvider);
}

if (isEmptyCredentials('gcp', kmsProviders)) {
Expand Down
18 changes: 17 additions & 1 deletion src/cmap/auth/aws_temporary_credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ export interface AWSTempCredentials {
Expiration?: Date;
}

/** @public **/
export type AWSCredentialProvider = () => Promise<AWSCredentials>;

/**
* @internal
*
Expand All @@ -41,7 +44,20 @@ export abstract class AWSTemporaryCredentialProvider {

/** @internal */
export class AWSSDKCredentialProvider extends AWSTemporaryCredentialProvider {
private _provider?: () => Promise<AWSCredentials>;
private _provider?: AWSCredentialProvider;

/**
* Create the SDK credentials provider.
* @param credentialsProvider - The credentials provider.
*/
constructor(credentialsProvider?: AWSCredentialProvider) {
super();

if (credentialsProvider) {
this._provider = credentialsProvider;
}
}

/**
* The AWS SDK caches credentials automatically and handles refresh when the credentials have expired.
* To ensure this occurs, we need to cache the `provider` returned by the AWS sdk and re-use it when fetching credentials.
Expand Down
3 changes: 3 additions & 0 deletions src/cmap/auth/mongo_credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
MongoInvalidArgumentError,
MongoMissingCredentialsError
} from '../../error';
import type { AWSCredentialProvider } from './aws_temporary_credentials';
import { GSSAPICanonicalizationValue } from './gssapi';
import type { OIDCCallbackFunction } from './mongodb_oidc';
import { AUTH_MECHS_AUTH_SRC_EXTERNAL, AuthMechanism } from './providers';
Expand Down Expand Up @@ -68,6 +69,8 @@ export interface AuthMechanismProperties extends Document {
ALLOWED_HOSTS?: string[];
/** The resource token for OIDC auth in Azure and GCP. */
TOKEN_RESOURCE?: string;
/** A custom AWS credential provider to use. */
AWS_CREDENTIAL_PROVIDER?: AWSCredentialProvider;
}

/** @public */
Expand Down
5 changes: 3 additions & 2 deletions src/cmap/auth/mongodb_aws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
import { ByteUtils, maxWireVersion, ns, randomBytes } from '../../utils';
import { type AuthContext, AuthProvider } from './auth_provider';
import {
type AWSCredentialProvider,
AWSSDKCredentialProvider,
type AWSTempCredentials,
AWSTemporaryCredentialProvider,
Expand All @@ -34,11 +35,11 @@ interface AWSSaslContinuePayload {

export class MongoDBAWS extends AuthProvider {
private credentialFetcher: AWSTemporaryCredentialProvider;
constructor() {
constructor(credentialProvider?: AWSCredentialProvider) {
super();

this.credentialFetcher = AWSTemporaryCredentialProvider.isAWSSDKInstalled
? new AWSSDKCredentialProvider()
? new AWSSDKCredentialProvider(credentialProvider)
: new LegacyAWSTemporaryCredentialProvider();
}

Expand Down
2 changes: 1 addition & 1 deletion src/deps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export function getZstdLibrary(): ZStandardLib | { kModuleError: MongoMissingDep
}

/**
* @internal
* @public
* Copy of the AwsCredentialIdentityProvider interface from [`smithy/types`](https://socket.dev/npm/package/\@smithy/types/files/1.1.1/dist-types/identity/awsCredentialIdentity.d.ts),
* the return type of the aws-sdk's `fromNodeProviderChain().provider()`.
*/
Expand Down
3 changes: 2 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ export { ReadPreferenceMode } from './read_preference';
export { ServerType, TopologyType } from './sdam/common';

// Helper classes
export type { AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials';
export type { AWSCredentials } from './deps';
export { ReadConcern } from './read_concern';
export { ReadPreference } from './read_preference';
export { WriteConcern } from './write_concern';

// events
export {
CommandFailedEvent,
Expand Down
10 changes: 8 additions & 2 deletions src/mongo_client_auth_providers.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { type AuthProvider } from './cmap/auth/auth_provider';
import { type AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials';
import { GSSAPI } from './cmap/auth/gssapi';
import { type AuthMechanismProperties } from './cmap/auth/mongo_credentials';
import { MongoDBAWS } from './cmap/auth/mongodb_aws';
Expand All @@ -13,8 +14,11 @@ import { X509 } from './cmap/auth/x509';
import { MongoInvalidArgumentError } from './error';

/** @internal */
const AUTH_PROVIDERS = new Map<AuthMechanism | string, (workflow?: Workflow) => AuthProvider>([
[AuthMechanism.MONGODB_AWS, () => new MongoDBAWS()],
const AUTH_PROVIDERS = new Map<AuthMechanism | string, (param?: any) => AuthProvider>([
[
AuthMechanism.MONGODB_AWS,
(credentialProvider?: AWSCredentialProvider) => new MongoDBAWS(credentialProvider)
],
[
AuthMechanism.MONGODB_CR,
() => {
Expand Down Expand Up @@ -65,6 +69,8 @@ export class MongoClientAuthProviders {
let provider;
if (name === AuthMechanism.MONGODB_OIDC) {
provider = providerFunction(this.getWorkflow(authMechanismProperties));
} else if (name === AuthMechanism.MONGODB_AWS) {
provider = providerFunction(authMechanismProperties.AWS_CREDENTIAL_PROVIDER);
} else {
provider = providerFunction();
}
Expand Down
58 changes: 54 additions & 4 deletions test/integration/auth/mongodb_aws.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,34 @@ describe('MONGODB-AWS', function () {
});
});

context('when user supplies a credentials provider', function () {
beforeEach(function () {
if (!awsSdkPresent) {
this.skipReason = 'only relevant to AssumeRoleWithWebIdentity with SDK installed';
return this.skip();
}
});

it('authenticates with a user provided credentials provider', async function () {
// @ts-expect-error We intentionally access a protected variable.
const credentialProvider = AWSTemporaryCredentialProvider.awsSDK;
client = this.configuration.newClient(process.env.MONGODB_URI, {
authMechanismProperties: {
AWS_CREDENTIAL_PROVIDER: credentialProvider.fromNodeProviderChain()
}
});

const result = await client
.db('aws')
.collection('aws_test')
.estimatedDocumentCount()
.catch(error => error);

expect(result).to.not.be.instanceOf(MongoServerError);
expect(result).to.be.a('number');
});
});

it('should allow empty string in authMechanismProperties.AWS_SESSION_TOKEN to override AWS_SESSION_TOKEN environment variable', function () {
client = this.configuration.newClient(this.configuration.url(), {
authMechanismProperties: { AWS_SESSION_TOKEN: '' }
Expand Down Expand Up @@ -426,11 +454,33 @@ describe('AWS KMS Credential Fetching', function () {
: undefined;
this.currentTest?.skipReason && this.skip();
});
it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
context('when a credential provider is not providered', function () {
it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
});
});

context('when a credential provider is provided', function () {
let credentialProvider;

beforeEach(function () {
// @ts-expect-error We intentionally access a protected variable.
credentialProvider = AWSTemporaryCredentialProvider.awsSDK;
});

it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials(
{ aws: {} },
credentialProvider.fromNodeProviderChain()
);

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
});
});

it('does not return any extra keys for the `aws` credential provider', async function () {
Expand Down