diff --git a/src/MongoDB.Driver/Core/Operations/BulkMixedWriteOperation.cs b/src/MongoDB.Driver/Core/Operations/BulkMixedWriteOperation.cs index efe28a3f4af..d18e1339069 100644 --- a/src/MongoDB.Driver/Core/Operations/BulkMixedWriteOperation.cs +++ b/src/MongoDB.Driver/Core/Operations/BulkMixedWriteOperation.cs @@ -139,10 +139,9 @@ public WriteConcern WriteConcern public BulkWriteOperationResult Execute(OperationContext operationContext, IWriteBinding binding) { using (BeginOperation()) - using (var context = RetryableWriteContext.Create(operationContext, binding, _retryRequested)) + using (var context = RetryableWriteContext.Create(operationContext, binding, IsOperationRetryable())) { EnsureHintIsSupportedIfAnyRequestHasHint(); - context.DisableRetriesIfAnyWriteRequestIsNotRetryable(_requests); var helper = new BatchHelper(_requests, _isOrdered, _writeConcern); foreach (var batch in helper.GetBatches()) { @@ -155,10 +154,9 @@ public BulkWriteOperationResult Execute(OperationContext operationContext, IWrit public async Task ExecuteAsync(OperationContext operationContext, IWriteBinding binding) { using (BeginOperation()) - using (var context = await RetryableWriteContext.CreateAsync(operationContext, binding, _retryRequested).ConfigureAwait(false)) + using (var context = await RetryableWriteContext.CreateAsync(operationContext, binding, IsOperationRetryable()).ConfigureAwait(false)) { EnsureHintIsSupportedIfAnyRequestHasHint(); - context.DisableRetriesIfAnyWriteRequestIsNotRetryable(_requests); var helper = new BatchHelper(_requests, _isOrdered, _writeConcern); foreach (var batch in helper.GetBatches()) { @@ -168,6 +166,9 @@ public async Task ExecuteAsync(OperationContext operat } } + private bool IsOperationRetryable() + => _retryRequested && _requests.All(r => r.IsRetryable()); + private IDisposable BeginOperation() => // Execution starts with the first request EventContext.BeginOperation(null, _requests.FirstOrDefault()?.RequestType.ToString().ToLower()); diff --git a/src/MongoDB.Driver/Core/Operations/BulkUnmixedWriteOperationBase.cs b/src/MongoDB.Driver/Core/Operations/BulkUnmixedWriteOperationBase.cs index 03ad4ad3e3c..0ec4cb33f69 100644 --- a/src/MongoDB.Driver/Core/Operations/BulkUnmixedWriteOperationBase.cs +++ b/src/MongoDB.Driver/Core/Operations/BulkUnmixedWriteOperationBase.cs @@ -129,9 +129,8 @@ public BulkWriteOperationResult Execute(OperationContext operationContext, Retry public BulkWriteOperationResult Execute(OperationContext operationContext, IWriteBinding binding) { using (BeginOperation()) - using (var context = RetryableWriteContext.Create(operationContext, binding, _retryRequested)) + using (var context = RetryableWriteContext.Create(operationContext, binding, IsOperationRetryable())) { - context.DisableRetriesIfAnyWriteRequestIsNotRetryable(_requests); return Execute(operationContext, context); } } @@ -146,9 +145,8 @@ public Task ExecuteAsync(OperationContext operationCon public async Task ExecuteAsync(OperationContext operationContext, IWriteBinding binding) { using (BeginOperation()) - using (var context = await RetryableWriteContext.CreateAsync(operationContext, binding, _retryRequested).ConfigureAwait(false)) + using (var context = await RetryableWriteContext.CreateAsync(operationContext, binding, IsOperationRetryable()).ConfigureAwait(false)) { - context.DisableRetriesIfAnyWriteRequestIsNotRetryable(_requests); return await ExecuteAsync(operationContext, context).ConfigureAwait(false); } } @@ -159,6 +157,9 @@ public async Task ExecuteAsync(OperationContext operat protected abstract bool RequestHasHint(TWriteRequest request); // private methods + private bool IsOperationRetryable() + => _retryRequested && _requests.All(r => r.IsRetryable()); + private IDisposable BeginOperation() => EventContext.BeginOperation(null, _requests.FirstOrDefault()?.RequestType.ToString().ToLower()); diff --git a/src/MongoDB.Driver/Core/Operations/DeleteRequest.cs b/src/MongoDB.Driver/Core/Operations/DeleteRequest.cs index 0eef8967778..78444e3f055 100644 --- a/src/MongoDB.Driver/Core/Operations/DeleteRequest.cs +++ b/src/MongoDB.Driver/Core/Operations/DeleteRequest.cs @@ -14,7 +14,6 @@ */ using MongoDB.Bson; -using MongoDB.Driver.Core.Connections; using MongoDB.Driver.Core.Misc; namespace MongoDB.Driver.Core.Operations @@ -36,6 +35,6 @@ public DeleteRequest(BsonDocument filter) public int Limit { get; init; } // public methods - public override bool IsRetryable(ConnectionDescription connectionDescription) => Limit != 0; + public override bool IsRetryable() => Limit != 0; } } diff --git a/src/MongoDB.Driver/Core/Operations/InsertRequest.cs b/src/MongoDB.Driver/Core/Operations/InsertRequest.cs index 90a3ef331ff..28c2c820f9b 100644 --- a/src/MongoDB.Driver/Core/Operations/InsertRequest.cs +++ b/src/MongoDB.Driver/Core/Operations/InsertRequest.cs @@ -14,7 +14,6 @@ */ using MongoDB.Bson; -using MongoDB.Driver.Core.Connections; using MongoDB.Driver.Core.Misc; namespace MongoDB.Driver.Core.Operations @@ -32,6 +31,6 @@ public InsertRequest(BsonDocument document) public BsonDocument Document { get; } // public methods - public override bool IsRetryable(ConnectionDescription connectionDescription) => true; + public override bool IsRetryable() => true; } } diff --git a/src/MongoDB.Driver/Core/Operations/RetryabilityHelper.cs b/src/MongoDB.Driver/Core/Operations/RetryabilityHelper.cs index fd8f735dff7..0ccde3f9c7d 100644 --- a/src/MongoDB.Driver/Core/Operations/RetryabilityHelper.cs +++ b/src/MongoDB.Driver/Core/Operations/RetryabilityHelper.cs @@ -1,4 +1,4 @@ -/* Copyright 2018-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -135,20 +135,17 @@ public static bool IsResumableChangeStreamException(Exception exception, int max { return exception is MongoException mongoException ? mongoException.HasErrorLabel(ResumableChangeStreamErrorLabel) : false; } - else + + if (exception is MongoCommandException commandException) { - var commandException = exception as MongoCommandException; - if (commandException != null) + var code = (ServerErrorCode)commandException.Code; + if (__resumableChangeStreamErrorCodes.Contains(code)) { - var code = (ServerErrorCode)commandException.Code; - if (__resumableChangeStreamErrorCodes.Contains(code)) - { - return true; - } + return true; } - - return __resumableChangeStreamExceptions.Contains(exception.GetType()); } + + return __resumableChangeStreamExceptions.Contains(exception.GetType()); } /// diff --git a/src/MongoDB.Driver/Core/Operations/RetryableReadContext.cs b/src/MongoDB.Driver/Core/Operations/RetryableReadContext.cs index 18d62e4b4a5..d1f8f36c15f 100644 --- a/src/MongoDB.Driver/Core/Operations/RetryableReadContext.cs +++ b/src/MongoDB.Driver/Core/Operations/RetryableReadContext.cs @@ -1,4 +1,4 @@ -/* Copyright 2019-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,11 @@ */ using System; +using System.Collections.Generic; using System.Threading.Tasks; using MongoDB.Driver.Core.Bindings; using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Core.Servers; namespace MongoDB.Driver.Core.Operations { @@ -29,20 +31,16 @@ public static RetryableReadContext Create(OperationContext operationContext, IRe var context = new RetryableReadContext(binding, retryRequested); try { - context.Initialize(operationContext); - - ChannelPinningHelper.PinChannellIfRequired( - context.ChannelSource, - context.Channel, - context.Binding.Session); - - return context; + context.AcquireOrReplaceChannel(operationContext, null); } catch { context.Dispose(); throw; } + + ChannelPinningHelper.PinChannellIfRequired(context.ChannelSource, context.Channel, context.Binding.Session); + return context; } public static async Task CreateAsync(OperationContext operationContext, IReadBinding binding, bool retryRequested) @@ -50,20 +48,16 @@ public static async Task CreateAsync(OperationContext oper var context = new RetryableReadContext(binding, retryRequested); try { - await context.InitializeAsync(operationContext).ConfigureAwait(false); - - ChannelPinningHelper.PinChannellIfRequired( - context.ChannelSource, - context.Channel, - context.Binding.Session); - - return context; + await context.AcquireOrReplaceChannelAsync(operationContext, null).ConfigureAwait(false); } catch { context.Dispose(); throw; } + + ChannelPinningHelper.PinChannellIfRequired(context.ChannelSource, context.Channel, context.Binding.Session); + return context; } #endregion @@ -96,14 +90,52 @@ public void Dispose() } } - public void ReplaceChannel(IChannelHandle channel) + public void AcquireOrReplaceChannel(OperationContext operationContext, IReadOnlyCollection deprioritizedServers) + { + var attempt = 1; + while (true) + { + operationContext.ThrowIfTimedOutOrCanceled(); + ReplaceChannelSource(Binding.GetReadChannelSource(operationContext, deprioritizedServers)); + try + { + ReplaceChannel(ChannelSource.GetChannel(operationContext)); + return; + } + catch (Exception ex) when (RetryableReadOperationExecutor.ShouldConnectionAcquireBeRetried(operationContext, this, ex, attempt)) + { + attempt++; + } + } + } + + public async Task AcquireOrReplaceChannelAsync(OperationContext operationContext, IReadOnlyCollection deprioritizedServers) + { + var attempt = 1; + while (true) + { + operationContext.ThrowIfTimedOutOrCanceled(); + ReplaceChannelSource(await Binding.GetReadChannelSourceAsync(operationContext, deprioritizedServers).ConfigureAwait(false)); + try + { + ReplaceChannel(await ChannelSource.GetChannelAsync(operationContext).ConfigureAwait(false)); + return; + } + catch (Exception ex) when (RetryableReadOperationExecutor.ShouldConnectionAcquireBeRetried(operationContext, this, ex, attempt)) + { + attempt++; + } + } + } + + private void ReplaceChannel(IChannelHandle channel) { Ensure.IsNotNull(channel, nameof(channel)); _channel?.Dispose(); _channel = channel; } - public void ReplaceChannelSource(IChannelSourceHandle channelSource) + private void ReplaceChannelSource(IChannelSourceHandle channelSource) { Ensure.IsNotNull(channelSource, nameof(channelSource)); _channelSource?.Dispose(); @@ -111,35 +143,5 @@ public void ReplaceChannelSource(IChannelSourceHandle channelSource) _channelSource = channelSource; _channel = null; } - - private void Initialize(OperationContext operationContext) - { - _channelSource = _binding.GetReadChannelSource(operationContext); - - try - { - _channel = _channelSource.GetChannel(operationContext); - } - catch (Exception ex) when (RetryableReadOperationExecutor.ShouldConnectionAcquireBeRetried(this, ex)) - { - ReplaceChannelSource(_binding.GetReadChannelSource(operationContext)); - ReplaceChannel(_channelSource.GetChannel(operationContext)); - } - } - - private async Task InitializeAsync(OperationContext operationContext) - { - _channelSource = await _binding.GetReadChannelSourceAsync(operationContext).ConfigureAwait(false); - - try - { - _channel = await _channelSource.GetChannelAsync(operationContext).ConfigureAwait(false); - } - catch (Exception ex) when (RetryableReadOperationExecutor.ShouldConnectionAcquireBeRetried(this, ex)) - { - ReplaceChannelSource(await _binding.GetReadChannelSourceAsync(operationContext).ConfigureAwait(false)); - ReplaceChannel(await _channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)); - } - } } } diff --git a/src/MongoDB.Driver/Core/Operations/RetryableReadOperationExecutor.cs b/src/MongoDB.Driver/Core/Operations/RetryableReadOperationExecutor.cs index d628382e264..77f9479e979 100644 --- a/src/MongoDB.Driver/Core/Operations/RetryableReadOperationExecutor.cs +++ b/src/MongoDB.Driver/Core/Operations/RetryableReadOperationExecutor.cs @@ -1,4 +1,4 @@ -/* Copyright 2019-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,121 +14,115 @@ */ using System; +using System.Collections.Generic; using System.Threading.Tasks; -using MongoDB.Driver.Core.Bindings; +using MongoDB.Driver.Core.Servers; namespace MongoDB.Driver.Core.Operations { internal static class RetryableReadOperationExecutor { // public static methods - public static TResult Execute(OperationContext operationContext, IRetryableReadOperation operation, IReadBinding binding, bool retryRequested) - { - using (var context = RetryableReadContext.Create(operationContext, binding, retryRequested)) - { - return Execute(operationContext, operation, context); - } - } - public static TResult Execute(OperationContext operationContext, IRetryableReadOperation operation, RetryableReadContext context) { - if (!ShouldReadBeRetried(context)) - { - return operation.ExecuteAttempt(operationContext, context, attempt: 1, transactionNumber: null); - } + HashSet deprioritizedServers = null; + var attempt = 1; + Exception originalException = null; - Exception originalException; - try - { - return operation.ExecuteAttempt(operationContext, context, attempt: 1, transactionNumber: null); + while (true) // Circle breaking logic based on ShouldRetryOperation method, see the catch block below. + { + operationContext.ThrowIfTimedOutOrCanceled(); + var server = context.ChannelSource.ServerDescription; + try + { + return operation.ExecuteAttempt(operationContext, context, attempt, transactionNumber: null); + } + catch (Exception ex) + { + if (!ShouldRetryOperation(operationContext, context, ex, attempt)) + { + throw originalException ?? ex; + } - } - catch (Exception ex) when (RetryabilityHelper.IsRetryableReadException(ex)) - { - originalException = ex; - } + originalException ??= ex; + } - try - { - context.ReplaceChannelSource(context.Binding.GetReadChannelSource(operationContext, new[] { context.ChannelSource.ServerDescription })); - context.ReplaceChannel(context.ChannelSource.GetChannel(operationContext)); - } - catch - { - throw originalException; - } + deprioritizedServers ??= new HashSet(); + deprioritizedServers.Add(server); - try - { - return operation.ExecuteAttempt(operationContext, context, attempt: 2, transactionNumber: null); - } - catch (Exception ex) when (ShouldThrowOriginalException(ex)) - { - throw originalException; - } - } + try + { + context.AcquireOrReplaceChannel(operationContext, deprioritizedServers); + } + catch + { + throw originalException; + } - public static async Task ExecuteAsync(OperationContext operationContext, IRetryableReadOperation operation, IReadBinding binding, bool retryRequested) - { - using (var context = await RetryableReadContext.CreateAsync(operationContext, binding, retryRequested).ConfigureAwait(false)) - { - return await ExecuteAsync(operationContext, operation, context).ConfigureAwait(false); + attempt++; } } public static async Task ExecuteAsync(OperationContext operationContext, IRetryableReadOperation operation, RetryableReadContext context) { - if (!ShouldReadBeRetried(context)) - { - return await operation.ExecuteAttemptAsync(operationContext, context, attempt: 1, transactionNumber: null).ConfigureAwait(false); - } + HashSet deprioritizedServers = null; + var attempt = 1; + Exception originalException = null; - Exception originalException; - try - { - return await operation.ExecuteAttemptAsync(operationContext, context, attempt: 1, transactionNumber: null).ConfigureAwait(false); - } - catch (Exception ex) when (RetryabilityHelper.IsRetryableReadException(ex)) - { - originalException = ex; - } + while (true) // Circle breaking logic based on ShouldRetryOperation method, see the catch block below. + { + operationContext.ThrowIfTimedOutOrCanceled(); + var server = context.ChannelSource.ServerDescription; + try + { + return await operation.ExecuteAttemptAsync(operationContext, context, attempt, transactionNumber: null).ConfigureAwait(false); + } + catch (Exception ex) + { + if (!ShouldRetryOperation(operationContext, context, ex, attempt)) + { + throw originalException ?? ex; + } - try - { - context.ReplaceChannelSource(context.Binding.GetReadChannelSource(operationContext, new[] { context.ChannelSource.ServerDescription })); - context.ReplaceChannel(context.ChannelSource.GetChannel(operationContext)); - } - catch - { - throw originalException; - } + originalException ??= ex; + } - try - { - return await operation.ExecuteAttemptAsync(operationContext, context, attempt: 2, transactionNumber: null).ConfigureAwait(false); - } - catch (Exception ex) when (ShouldThrowOriginalException(ex)) - { - throw originalException; + deprioritizedServers ??= new HashSet(); + deprioritizedServers.Add(server); + + try + { + await context.AcquireOrReplaceChannelAsync(operationContext, deprioritizedServers).ConfigureAwait(false); + } + catch + { + throw originalException; + } + + attempt++; } } - public static bool ShouldConnectionAcquireBeRetried(RetryableReadContext context, Exception ex) + public static bool ShouldConnectionAcquireBeRetried(OperationContext operationContext, RetryableReadContext context, Exception exception, int attempt) { - // According the spec error during handshake should be handle according to RetryableReads logic - var innerException = ex is MongoAuthenticationException mongoAuthenticationException ? mongoAuthenticationException.InnerException : ex; - return context.RetryRequested && !context.Binding.Session.IsInTransaction && RetryabilityHelper.IsRetryableReadException(innerException); + var innerException = exception is MongoAuthenticationException mongoAuthenticationException ? mongoAuthenticationException.InnerException : exception; + return ShouldRetryOperation(operationContext, context, innerException, attempt); } // private static methods - private static bool ShouldReadBeRetried(RetryableReadContext context) + private static bool ShouldRetryOperation(OperationContext operationContext, RetryableReadContext context, Exception exception, int attempt) { - return context.RetryRequested && !context.Binding.Session.IsInTransaction; - } + if (!context.RetryRequested || context.Binding.Session.IsInTransaction) + { + return false; + } - private static bool ShouldThrowOriginalException(Exception retryException) - { - return retryException is MongoException && !(retryException is MongoConnectionException); + if (!RetryabilityHelper.IsRetryableReadException(exception)) + { + return false; + } + + return operationContext.IsRootContextTimeoutConfigured() || attempt < 2; } } } diff --git a/src/MongoDB.Driver/Core/Operations/RetryableWriteContext.cs b/src/MongoDB.Driver/Core/Operations/RetryableWriteContext.cs index 9c15d1e9bfc..cbf64188d99 100644 --- a/src/MongoDB.Driver/Core/Operations/RetryableWriteContext.cs +++ b/src/MongoDB.Driver/Core/Operations/RetryableWriteContext.cs @@ -1,4 +1,4 @@ -/* Copyright 2017-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,10 +15,10 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Threading.Tasks; using MongoDB.Driver.Core.Bindings; using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Core.Servers; namespace MongoDB.Driver.Core.Operations { @@ -31,20 +31,16 @@ public static RetryableWriteContext Create(OperationContext operationContext, IW var context = new RetryableWriteContext(binding, retryRequested); try { - context.Initialize(operationContext); - - ChannelPinningHelper.PinChannellIfRequired( - context.ChannelSource, - context.Channel, - context.Binding.Session); - - return context; + context.AcquireOrReplaceChannel(operationContext, null); } catch { context.Dispose(); throw; } + + ChannelPinningHelper.PinChannellIfRequired(context.ChannelSource, context.Channel, context.Binding.Session); + return context; } public static async Task CreateAsync(OperationContext operationContext, IWriteBinding binding, bool retryRequested) @@ -52,20 +48,16 @@ public static async Task CreateAsync(OperationContext ope var context = new RetryableWriteContext(binding, retryRequested); try { - await context.InitializeAsync(operationContext).ConfigureAwait(false); - - ChannelPinningHelper.PinChannellIfRequired( - context.ChannelSource, - context.Channel, - context.Binding.Session); - - return context; + await context.AcquireOrReplaceChannelAsync(operationContext, null).ConfigureAwait(false); } catch { context.Dispose(); throw; } + + ChannelPinningHelper.PinChannellIfRequired(context.ChannelSource, context.Channel, context.Binding.Session); + return context; } #endregion @@ -88,35 +80,64 @@ public RetryableWriteContext(IWriteBinding binding, bool retryRequested) public IChannelSourceHandle ChannelSource => _channelSource; public bool RetryRequested => _retryRequested; - public void DisableRetriesIfAnyWriteRequestIsNotRetryable(IEnumerable requests) + public void Dispose() { - if (_retryRequested) + if (!_disposed) { - if (requests.Any(r => !r.IsRetryable(_channel.ConnectionDescription))) + _channelSource?.Dispose(); + _channel?.Dispose(); + _disposed = true; + } + } + + public void AcquireOrReplaceChannel(OperationContext operationContext, IReadOnlyCollection deprioritizedServers) + { + var attempt = 1; + while (true) + { + operationContext.ThrowIfTimedOutOrCanceled(); + ReplaceChannelSource(Binding.GetWriteChannelSource(operationContext, deprioritizedServers)); + var server = ChannelSource.ServerDescription; + try + { + ReplaceChannel(ChannelSource.GetChannel(operationContext)); + return; + } + catch (Exception ex) when (RetryableWriteOperationExecutor.ShouldConnectionAcquireBeRetried(operationContext, this, server, ex, attempt)) { - _retryRequested = false; + attempt++; } } } - public void Dispose() + public async Task AcquireOrReplaceChannelAsync(OperationContext operationContext, IReadOnlyCollection deprioritizedServers) { - if (!_disposed) + var attempt = 1; + while (true) { - _channelSource?.Dispose(); - _channel?.Dispose(); - _disposed = true; + operationContext.ThrowIfTimedOutOrCanceled(); + ReplaceChannelSource(await Binding.GetWriteChannelSourceAsync(operationContext, deprioritizedServers).ConfigureAwait(false)); + var server = ChannelSource.ServerDescription; + try + { + ReplaceChannel(await ChannelSource.GetChannelAsync(operationContext).ConfigureAwait(false)); + return; + } + catch (Exception ex) when (RetryableWriteOperationExecutor.ShouldConnectionAcquireBeRetried(operationContext, this, server, ex, attempt)) + { + attempt++; + } } } - public void ReplaceChannel(IChannelHandle channel) + private void ReplaceChannel(IChannelHandle channel) { Ensure.IsNotNull(channel, nameof(channel)); _channel?.Dispose(); _channel = channel; } - public void ReplaceChannelSource(IChannelSourceHandle channelSource) + private void ReplaceChannelSource(IChannelSourceHandle channelSource) { Ensure.IsNotNull(channelSource, nameof(channelSource)); _channelSource?.Dispose(); @@ -124,37 +145,5 @@ public void ReplaceChannelSource(IChannelSourceHandle channelSource) _channelSource = channelSource; _channel = null; } - - private void Initialize(OperationContext operationContext) - { - _channelSource = _binding.GetWriteChannelSource(operationContext); - var serverDescription = _channelSource.ServerDescription; - - try - { - _channel = _channelSource.GetChannel(operationContext); - } - catch (Exception ex) when (RetryableWriteOperationExecutor.ShouldConnectionAcquireBeRetried(this, serverDescription, ex)) - { - ReplaceChannelSource(_binding.GetWriteChannelSource(operationContext)); - ReplaceChannel(_channelSource.GetChannel(operationContext)); - } - } - - private async Task InitializeAsync(OperationContext operationContext) - { - _channelSource = await _binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false); - var serverDescription = _channelSource.ServerDescription; - - try - { - _channel = await _channelSource.GetChannelAsync(operationContext).ConfigureAwait(false); - } - catch (Exception ex) when (RetryableWriteOperationExecutor.ShouldConnectionAcquireBeRetried(this, serverDescription, ex)) - { - ReplaceChannelSource(await _binding.GetWriteChannelSourceAsync(operationContext).ConfigureAwait(false)); - ReplaceChannel(await _channelSource.GetChannelAsync(operationContext).ConfigureAwait(false)); - } - } } } diff --git a/src/MongoDB.Driver/Core/Operations/RetryableWriteOperationExecutor.cs b/src/MongoDB.Driver/Core/Operations/RetryableWriteOperationExecutor.cs index f654d614260..ff8568bda36 100644 --- a/src/MongoDB.Driver/Core/Operations/RetryableWriteOperationExecutor.cs +++ b/src/MongoDB.Driver/Core/Operations/RetryableWriteOperationExecutor.cs @@ -1,4 +1,4 @@ -/* Copyright 2017-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ */ using System; +using System.Collections.Generic; using System.Threading.Tasks; using MongoDB.Driver.Core.Bindings; -using MongoDB.Driver.Core.Connections; using MongoDB.Driver.Core.Servers; namespace MongoDB.Driver.Core.Operations @@ -34,44 +34,48 @@ public static TResult Execute(OperationContext operationContext, IRetry public static TResult Execute(OperationContext operationContext, IRetryableWriteOperation operation, RetryableWriteContext context) { - if (!AreRetriesAllowed(operation, context)) - { - return operation.ExecuteAttempt(operationContext, context, 1, null); - } - - var transactionNumber = context.Binding.Session.AdvanceTransactionNumber(); - Exception originalException; - try - { - return operation.ExecuteAttempt(operationContext, context, 1, transactionNumber); - } - catch (Exception ex) when (RetryabilityHelper.IsRetryableWriteException(ex)) - { - originalException = ex; - } - - try - { - context.ReplaceChannelSource(context.Binding.GetWriteChannelSource(operationContext, new[] { context.ChannelSource.ServerDescription })); - context.ReplaceChannel(context.ChannelSource.GetChannel(operationContext)); - } - catch - { - throw originalException; - } - - if (!AreRetryableWritesSupported(context.Channel.ConnectionDescription)) - { - throw originalException; - } - - try - { - return operation.ExecuteAttempt(operationContext, context, 2, transactionNumber); - } - catch (Exception ex) when (ShouldThrowOriginalException(ex)) - { - throw originalException; + HashSet deprioritizedServers = null; + var attempt = 1; + Exception originalException = null; + + long? transactionNumber = AreRetriesAllowed(operation.WriteConcern, context, context.ChannelSource.ServerDescription) ? context.Binding.Session.AdvanceTransactionNumber() : null; + + while (true) // Circle breaking logic based on ShouldRetryOperation method, see the catch block below. + { + operationContext.ThrowIfTimedOutOrCanceled(); + var server = context.ChannelSource.ServerDescription; + try + { + return operation.ExecuteAttempt(operationContext, context, attempt, transactionNumber); + } + catch (Exception ex) + { + if (!ShouldRetryOperation(operationContext, operation.WriteConcern, context, server, ex, attempt)) + { + throw originalException ?? ex; + } + + originalException ??= ex; + } + + deprioritizedServers ??= new HashSet(); + deprioritizedServers.Add(server); + + try + { + context.AcquireOrReplaceChannel(operationContext, deprioritizedServers); + } + catch + { + throw originalException; + } + + if (!AreRetryableWritesSupported(context.ChannelSource.ServerDescription)) + { + throw originalException; + } + + attempt++; } } @@ -85,72 +89,86 @@ public async static Task ExecuteAsync(OperationContext operati public static async Task ExecuteAsync(OperationContext operationContext, IRetryableWriteOperation operation, RetryableWriteContext context) { - if (!AreRetriesAllowed(operation, context)) - { - return await operation.ExecuteAttemptAsync(operationContext, context, 1, null).ConfigureAwait(false); + HashSet deprioritizedServers = null; + var attempt = 1; + Exception originalException = null; + + long? transactionNumber = AreRetriesAllowed(operation.WriteConcern, context, context.ChannelSource.ServerDescription) ? context.Binding.Session.AdvanceTransactionNumber() : null; + + while (true) // Circle breaking logic based on ShouldRetryOperation method, see the catch block below. + { + operationContext.ThrowIfTimedOutOrCanceled(); + var server = context.ChannelSource.ServerDescription; + try + { + return await operation.ExecuteAttemptAsync(operationContext, context, attempt, transactionNumber).ConfigureAwait(false); + } + catch (Exception ex) + { + if (!ShouldRetryOperation(operationContext, operation.WriteConcern, context, server, ex, attempt)) + { + throw originalException ?? ex; + } + + originalException ??= ex; + } + + deprioritizedServers ??= new HashSet(); + deprioritizedServers.Add(server); + + try + { + await context.AcquireOrReplaceChannelAsync(operationContext, deprioritizedServers).ConfigureAwait(false); + } + catch + { + throw originalException; + } + + if (!AreRetryableWritesSupported(context.ChannelSource.ServerDescription)) + { + throw originalException; + } + + attempt++; } + } - var transactionNumber = context.Binding.Session.AdvanceTransactionNumber(); - Exception originalException; - try - { - return await operation.ExecuteAttemptAsync(operationContext, context, 1, transactionNumber).ConfigureAwait(false); - } - catch (Exception ex) when (RetryabilityHelper.IsRetryableWriteException(ex)) + public static bool ShouldConnectionAcquireBeRetried(OperationContext operationContext, RetryableWriteContext context, ServerDescription server, Exception exception, int attempt) + { + if (!DoesContextAllowRetries(context, server)) { - originalException = ex; + return false; } - try - { - context.ReplaceChannelSource(await context.Binding.GetWriteChannelSourceAsync(operationContext, new[] { context.ChannelSource.ServerDescription }).ConfigureAwait(false)); - context.ReplaceChannel(await context.ChannelSource.GetChannelAsync(operationContext).ConfigureAwait(false)); - } - catch + var innerException = exception is MongoAuthenticationException mongoAuthenticationException ? mongoAuthenticationException.InnerException : exception; + // According the spec error during handshake should be handle according to RetryableReads logic + if (!RetryabilityHelper.IsRetryableReadException(innerException)) { - throw originalException; + return false; } - if (!AreRetryableWritesSupported(context.Channel.ConnectionDescription)) - { - throw originalException; - } + return operationContext.IsRootContextTimeoutConfigured() || attempt < 2; + } - try + // private static methods + private static bool ShouldRetryOperation(OperationContext operationContext, WriteConcern writeConcern, RetryableWriteContext context, ServerDescription server, Exception exception, int attempt) + { + if (!AreRetriesAllowed(writeConcern, context, server)) { - return await operation.ExecuteAttemptAsync(operationContext, context, 2, transactionNumber).ConfigureAwait(false); + return false; } - catch (Exception ex) when (ShouldThrowOriginalException(ex)) + + if (!RetryabilityHelper.IsRetryableWriteException(exception)) { - throw originalException; + return false; } - } - - public static bool ShouldConnectionAcquireBeRetried(RetryableWriteContext context, ServerDescription serverDescription, Exception exception) - { - var innerException = exception is MongoAuthenticationException mongoAuthenticationException ? mongoAuthenticationException.InnerException : exception; - // According the spec error during handshake should be handle according to RetryableReads logic - return context.RetryRequested && - AreRetryableWritesSupported(serverDescription) && - context.Binding.Session.Id != null && - !context.Binding.Session.IsInTransaction && - RetryabilityHelper.IsRetryableReadException(innerException); - } - - // private static methods - private static bool AreRetriesAllowed(IRetryableWriteOperation operation, RetryableWriteContext context) - { - return IsOperationAcknowledged(operation) && DoesContextAllowRetries(context); + return operationContext.IsRootContextTimeoutConfigured() || attempt < 2; } - private static bool AreRetryableWritesSupported(ConnectionDescription connectionDescription) - { - var helloResult = connectionDescription.HelloResult; - return - helloResult.ServerType == ServerType.LoadBalanced || - (helloResult.LogicalSessionTimeout != null && helloResult.ServerType != ServerType.Standalone); - } + private static bool AreRetriesAllowed(WriteConcern writeConcern, RetryableWriteContext context, ServerDescription server) + => IsOperationAcknowledged(writeConcern) && DoesContextAllowRetries(context, server); private static bool AreRetryableWritesSupported(ServerDescription serverDescription) { @@ -158,25 +176,14 @@ private static bool AreRetryableWritesSupported(ServerDescription serverDescript (serverDescription.LogicalSessionTimeout != null && serverDescription.Type != ServerType.Standalone); } - private static bool DoesContextAllowRetries(RetryableWriteContext context) - { - return - context.RetryRequested && - AreRetryableWritesSupported(context.Channel.ConnectionDescription) && - context.Binding.Session.Id != null && - !context.Binding.Session.IsInTransaction; - } - - private static bool IsOperationAcknowledged(IRetryableWriteOperation operation) - { - var writeConcern = operation.WriteConcern; - return - writeConcern == null || // null means use server default write concern which implies acknowledged - writeConcern.IsAcknowledged; - } + private static bool DoesContextAllowRetries(RetryableWriteContext context, ServerDescription server) + => context.RetryRequested && + AreRetryableWritesSupported(server) && + context.Binding.Session.Id != null && + !context.Binding.Session.IsInTransaction; - private static bool ShouldThrowOriginalException(Exception retryException) => - retryException == null || - retryException is MongoException && !(retryException is MongoConnectionException || retryException is MongoConnectionPoolPausedException); + private static bool IsOperationAcknowledged(WriteConcern writeConcern) + => writeConcern == null || // null means use server default write concern which implies acknowledged + writeConcern.IsAcknowledged; } } diff --git a/src/MongoDB.Driver/Core/Operations/UpdateRequest.cs b/src/MongoDB.Driver/Core/Operations/UpdateRequest.cs index 88810698d4c..af885a697be 100644 --- a/src/MongoDB.Driver/Core/Operations/UpdateRequest.cs +++ b/src/MongoDB.Driver/Core/Operations/UpdateRequest.cs @@ -16,7 +16,6 @@ using System; using System.Collections.Generic; using MongoDB.Bson; -using MongoDB.Driver.Core.Connections; using MongoDB.Driver.Core.Misc; namespace MongoDB.Driver.Core.Operations @@ -44,7 +43,7 @@ public UpdateRequest(UpdateType updateType, BsonDocument filter, BsonValue updat public UpdateType UpdateType { get; init; } // public methods - public override bool IsRetryable(ConnectionDescription connectionDescription) => !IsMulti; + public override bool IsRetryable() => !IsMulti; // private methods private static BsonValue EnsureUpdateIsValid(BsonValue update, UpdateType updateType) diff --git a/src/MongoDB.Driver/Core/Operations/WriteRequest.cs b/src/MongoDB.Driver/Core/Operations/WriteRequest.cs index 104d525c574..1da745c87f0 100644 --- a/src/MongoDB.Driver/Core/Operations/WriteRequest.cs +++ b/src/MongoDB.Driver/Core/Operations/WriteRequest.cs @@ -13,9 +13,6 @@ * limitations under the License. */ -using System; -using MongoDB.Driver.Core.Connections; - namespace MongoDB.Driver.Core.Operations { internal abstract class WriteRequest @@ -31,6 +28,6 @@ protected WriteRequest(WriteRequestType requestType) public WriteRequestType RequestType { get; init; } // public methods - public abstract bool IsRetryable(ConnectionDescription connectionDescription); + public abstract bool IsRetryable(); } } diff --git a/src/MongoDB.Driver/OperationContext.cs b/src/MongoDB.Driver/OperationContext.cs index c0ccd67919f..d2359de1c9f 100644 --- a/src/MongoDB.Driver/OperationContext.cs +++ b/src/MongoDB.Driver/OperationContext.cs @@ -36,11 +36,12 @@ internal OperationContext(Stopwatch stopwatch, TimeSpan timeout, CancellationTok Stopwatch = stopwatch; Timeout = Ensure.IsInfiniteOrGreaterThanOrEqualToZero(timeout, nameof(timeout)); CancellationToken = cancellationToken; + RootContext = this; } public CancellationToken CancellationToken { get; } - public OperationContext ParentContext { get; private init; } + public OperationContext RootContext { get; private init; } public TimeSpan RemainingTimeout { @@ -159,7 +160,7 @@ public OperationContext WithTimeout(TimeSpan timeout) return new OperationContext(timeout, CancellationToken) { - ParentContext = this + RootContext = RootContext }; } } diff --git a/src/MongoDB.Driver/OperationContextExtensions.cs b/src/MongoDB.Driver/OperationContextExtensions.cs new file mode 100644 index 00000000000..dace4a05bdf --- /dev/null +++ b/src/MongoDB.Driver/OperationContextExtensions.cs @@ -0,0 +1,25 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Threading; + +namespace MongoDB.Driver +{ + internal static class OperationContextExtensions + { + public static bool IsRootContextTimeoutConfigured(this OperationContext operationContext) + => operationContext.RootContext.Timeout != Timeout.InfiniteTimeSpan; + } +} diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs index f866f268877..325eb5536f7 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/ReadCommandOperationTests.cs @@ -1,4 +1,4 @@ -/* Copyright 2016-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * limitations under the License. */ +using System.Collections.Generic; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -320,8 +321,8 @@ private Mock CreateMockReadBinding(ReadPreference readPreference, var mockSession = new Mock(); mockBinding.SetupGet(b => b.ReadPreference).Returns(readPreference); mockBinding.SetupGet(b => b.Session).Returns(mockSession.Object); - mockBinding.Setup(b => b.GetReadChannelSource(It.IsAny())).Returns(channelSource); - mockBinding.Setup(b => b.GetReadChannelSourceAsync(It.IsAny())).Returns(Task.FromResult(channelSource)); + mockBinding.Setup(b => b.GetReadChannelSource(It.IsAny(), It.IsAny>())).Returns(channelSource); + mockBinding.Setup(b => b.GetReadChannelSourceAsync(It.IsAny(), It.IsAny>())).Returns(Task.FromResult(channelSource)); return mockBinding; } diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/RetryableReadOperationExecutorTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/RetryableReadOperationExecutorTests.cs new file mode 100644 index 00000000000..4eec630f34d --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Core/Operations/RetryableReadOperationExecutorTests.cs @@ -0,0 +1,80 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.IO; +using System.Threading; +using MongoDB.Bson.TestHelpers; +using MongoDB.Driver.Core.Bindings; +using MongoDB.Driver.Core.Operations; +using MongoDB.Driver.Core.TestHelpers; +using Moq; +using Xunit; + +namespace MongoDB.Driver.Core.Tests.Core.Operations +{ + public class RetryableReadOperationExecutorTests + { + [Theory] + // No retries if retryRequested == false + [InlineData(false, false, false, true, false, 1)] + [InlineData(false, false, false, true, true, 1)] + // No retries if in transaction + [InlineData(false, true, true, true, false, 1)] + [InlineData(false, true, true, true, true, 1)] + // No retries in non-retriable exception + [InlineData(false, true, false, false, false, 1)] + [InlineData(false, true, false, false, true, 1)] + // No timeout configured - should retry once + [InlineData(true, true, false, true, false, 1)] + [InlineData(false, true, false, true, false, 2)] + // Timeout configured - should retry as many times as possible + [InlineData(true, true, false, true, true, 1)] + [InlineData(true, true, false, true, true, 2)] + [InlineData(true, true, false, true, true, 10)] + public void ShouldRetryOperation_should_return_expected_result( + bool expected, + bool isRetryRequested, + bool isInTransaction, + bool isRetriableException, + bool hasTimeout, + int attempt) + { + var retryableReadContext = CreateSubject(isRetryRequested, isInTransaction); + var exception =CoreExceptionHelper.CreateException(isRetriableException ? nameof(MongoNodeIsRecoveringException) : nameof(IOException)); + var operationContext = new OperationContext(hasTimeout ? TimeSpan.FromSeconds(42) : Timeout.InfiniteTimeSpan, CancellationToken.None); + + var result = RetryableReadOperationExecutorReflector.ShouldRetryOperation(operationContext, retryableReadContext, exception, attempt); + + Assert.Equal(expected, result); + } + + private static RetryableReadContext CreateSubject(bool retryRequested, bool isInTransaction) + { + var sessionMock = new Mock(); + sessionMock.SetupGet(m => m.IsInTransaction).Returns(isInTransaction); + var bindingMock = new Mock(); + bindingMock.SetupGet(m => m.Session).Returns(sessionMock.Object); + return new RetryableReadContext(bindingMock.Object, retryRequested); + } + + private static class RetryableReadOperationExecutorReflector + { + public static bool ShouldRetryOperation(OperationContext operationContext, RetryableReadContext context, Exception exception, int attempt) + => (bool)Reflector.InvokeStatic(typeof(RetryableReadOperationExecutor), nameof(ShouldRetryOperation), operationContext, context, exception, attempt); + } + } +} + diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/RetryableWriteOperationExecutorTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/RetryableWriteOperationExecutorTests.cs index 471c4a0a931..d01706d6128 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/RetryableWriteOperationExecutorTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/RetryableWriteOperationExecutorTests.cs @@ -1,4 +1,4 @@ -/* Copyright 2020-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,17 +13,14 @@ * limitations under the License. */ -using System.Linq; +using System; +using System.Collections.Generic; using System.Net; -using System.Reflection; -using System.Threading; using FluentAssertions; using MongoDB.Bson; using MongoDB.Bson.TestHelpers; using MongoDB.Driver.Core.Bindings; using MongoDB.Driver.Core.Clusters; -using MongoDB.Driver.Core.Connections; -using MongoDB.Driver.Core.Misc; using MongoDB.Driver.Core.Operations; using MongoDB.Driver.Core.Servers; using Moq; @@ -36,9 +33,9 @@ public class RetryableWriteOperationExecutorTests [Fact] public void AreRetryableWritesSupportedTest() { - var connectionDescription = CreateConnectionDescription(withLogicalSessionTimeout: false, serviceId: true); + var serverDescription = CreateServerDescription(withLogicalSessionTimeout: false, isLoadBalanced: true); - var result = RetryableWriteOperationExecutorReflector.AreRetryableWritesSupported(connectionDescription); + var result = RetryableWriteOperationExecutorReflector.AreRetryableWritesSupported(serverDescription); result.Should().BeTrue(); } @@ -69,24 +66,20 @@ public void DoesContextAllowRetries_should_return_expected_result( { var context = CreateContext(retryRequested, areRetryableWritesSupported, hasSessionId, isInTransaction); - var result = RetryableWriteOperationExecutorReflector.DoesContextAllowRetries(context); + var result = RetryableWriteOperationExecutorReflector.DoesContextAllowRetries(context, context.ChannelSource.ServerDescription); result.Should().Be(expectedResult); } [Theory] - [InlineData(false, false, true)] - [InlineData(false, true, true)] - [InlineData(true, false, false)] - [InlineData(true, true, true)] - public void IsOperationAcknowledged_should_return_expected_result( - bool withWriteConcern, - bool isAcknowledged, - bool expectedResult) + [InlineData(null, true)] + [InlineData(false, false)] + [InlineData(true, true)] + public void IsOperationAcknowledged_should_return_expected_result(bool? isAcknowledged, bool expectedResult) { - var operation = CreateOperation(withWriteConcern, isAcknowledged); + var writeConcern = isAcknowledged.HasValue ? (isAcknowledged.Value ? WriteConcern.Acknowledged : WriteConcern.Unacknowledged) : null; - var result = RetryableWriteOperationExecutorReflector.IsOperationAcknowledged(operation); + var result = RetryableWriteOperationExecutorReflector.IsOperationAcknowledged(writeConcern); result.Should().Be(expectedResult); } @@ -98,59 +91,35 @@ private IWriteBinding CreateBinding(bool areRetryableWritesSupported, bool hasSe var session = CreateSession(hasSessionId, isInTransaction); var channelSource = CreateChannelSource(areRetryableWritesSupported); mockBinding.SetupGet(m => m.Session).Returns(session); - mockBinding.Setup(m => m.GetWriteChannelSource(It.IsAny())).Returns(channelSource); + mockBinding.Setup(m => m.GetWriteChannelSource(It.IsAny(), It.IsAny>())).Returns(channelSource); return mockBinding.Object; } - private IChannelHandle CreateChannel(bool areRetryableWritesSupported) - { - var mockChannel = new Mock(); - var connectionDescription = CreateConnectionDescription(withLogicalSessionTimeout: areRetryableWritesSupported); - mockChannel.SetupGet(m => m.ConnectionDescription).Returns(connectionDescription); - return mockChannel.Object; - } - private IChannelSourceHandle CreateChannelSource(bool areRetryableWritesSupported) { var mockChannelSource = new Mock(); - var channel = CreateChannel(areRetryableWritesSupported); + var channel = Mock.Of(); mockChannelSource.Setup(m => m.GetChannel(It.IsAny())).Returns(channel); + mockChannelSource.Setup(m => m.ServerDescription).Returns(CreateServerDescription(withLogicalSessionTimeout: areRetryableWritesSupported)); return mockChannelSource.Object; } - private ConnectionDescription CreateConnectionDescription(bool withLogicalSessionTimeout, bool? serviceId = null) + private ServerDescription CreateServerDescription(bool withLogicalSessionTimeout, bool isLoadBalanced = false) { var clusterId = new ClusterId(1); var endPoint = new DnsEndPoint("localhost", 27017); var serverId = new ServerId(clusterId, endPoint); - var connectionId = new ConnectionId(serverId, 1); - var helloResultDocument = BsonDocument.Parse($"{{ ok : 1, maxWireVersion : {WireVersion.Server42} }}"); - if (withLogicalSessionTimeout) - { - helloResultDocument["logicalSessionTimeoutMinutes"] = 1; - helloResultDocument["msg"] = "isdbgrid"; // mongos - } - if (serviceId.HasValue) - { - helloResultDocument["serviceId"] = ObjectId.Empty; // load balancing mode - } - var helloResult = new HelloResult(helloResultDocument); - var connectionDescription = new ConnectionDescription(connectionId, helloResult); - return connectionDescription; + TimeSpan? logicalSessionTimeout = withLogicalSessionTimeout ? TimeSpan.FromMinutes(1) : null; + var serverType = isLoadBalanced ? ServerType.LoadBalanced : ServerType.ShardRouter; + + return new ServerDescription(serverId, endPoint, logicalSessionTimeout: logicalSessionTimeout, type: serverType); } private RetryableWriteContext CreateContext(bool retryRequested, bool areRetryableWritesSupported, bool hasSessionId, bool isInTransaction) { var binding = CreateBinding(areRetryableWritesSupported, hasSessionId, isInTransaction); - return RetryableWriteContext.Create(OperationContext.NoTimeout, binding, retryRequested); - } - - private IRetryableWriteOperation CreateOperation(bool withWriteConcern, bool isAcknowledged) - { - var mockOperation = new Mock>(); - var writeConcern = withWriteConcern ? (isAcknowledged ? WriteConcern.Acknowledged : WriteConcern.Unacknowledged) : null; - mockOperation.SetupGet(m => m.WriteConcern).Returns(writeConcern); - return mockOperation.Object; + var context = RetryableWriteContext.Create(OperationContext.NoTimeout, binding, retryRequested); + return context; } private ICoreSessionHandle CreateSession(bool hasSessionId, bool isInTransaction) @@ -165,28 +134,13 @@ private ICoreSessionHandle CreateSession(bool hasSessionId, bool isInTransaction // nested types internal static class RetryableWriteOperationExecutorReflector { - public static bool AreRetryableWritesSupported(ConnectionDescription connectionDescription) - { - return (bool)Reflector.InvokeStatic(typeof(RetryableWriteOperationExecutor), nameof(AreRetryableWritesSupported), connectionDescription); - } + public static bool AreRetryableWritesSupported(ServerDescription serverDescription) + => (bool)Reflector.InvokeStatic(typeof(RetryableWriteOperationExecutor), nameof(AreRetryableWritesSupported), serverDescription); - public static bool DoesContextAllowRetries(RetryableWriteContext context) => - (bool)Reflector.InvokeStatic(typeof(RetryableWriteOperationExecutor), nameof(DoesContextAllowRetries), context); + public static bool DoesContextAllowRetries(RetryableWriteContext context, ServerDescription server) + => (bool)Reflector.InvokeStatic(typeof(RetryableWriteOperationExecutor), nameof(DoesContextAllowRetries), context, server); - public static bool IsOperationAcknowledged(IRetryableWriteOperation operation) - { - var methodInfoDefinition = typeof(RetryableWriteOperationExecutor).GetMethods(BindingFlags.NonPublic | BindingFlags.Static) - .Where(m => m.Name == nameof(IsOperationAcknowledged)) - .Single(); - var methodInfo = methodInfoDefinition.MakeGenericMethod(typeof(BsonDocument)); - try - { - return (bool)methodInfo.Invoke(null, new object[] { operation }); - } - catch (TargetInvocationException exception) - { - throw exception.InnerException; - } - } + public static bool IsOperationAcknowledged(WriteConcern writeConcern) + => (bool)Reflector.InvokeStatic(typeof(RetryableWriteOperationExecutor), nameof(IsOperationAcknowledged), writeConcern); } } diff --git a/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs b/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs index bc44255fa2c..93b655b3e19 100644 --- a/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Operations/WriteCommandOperationTests.cs @@ -1,4 +1,4 @@ -/* Copyright 2016-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * limitations under the License. */ +using System.Collections.Generic; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -216,6 +217,8 @@ private Mock CreateMockWriteBinding(IChannelSourceHandle channelS mockBinding.SetupGet(b => b.Session).Returns(mockSession.Object); mockBinding.Setup(b => b.GetWriteChannelSource(It.IsAny())).Returns(channelSource); mockBinding.Setup(b => b.GetWriteChannelSourceAsync(It.IsAny())).Returns(Task.FromResult(channelSource)); + mockBinding.Setup(b => b.GetWriteChannelSource(It.IsAny(), It.IsAny>())).Returns(channelSource); + mockBinding.Setup(b => b.GetWriteChannelSourceAsync(It.IsAny(), It.IsAny>())).Returns(Task.FromResult(channelSource)); return mockBinding; } diff --git a/tests/MongoDB.Driver.Tests/OperationContextTests.cs b/tests/MongoDB.Driver.Tests/OperationContextTests.cs index 49644802605..a3d00a95528 100644 --- a/tests/MongoDB.Driver.Tests/OperationContextTests.cs +++ b/tests/MongoDB.Driver.Tests/OperationContextTests.cs @@ -39,7 +39,7 @@ public void Constructor_should_initialize_properties() operationContext.Timeout.Should().Be(timeout); operationContext.RemainingTimeout.Should().Be(timeout); operationContext.CancellationToken.Should().Be(cancellationToken); - operationContext.ParentContext.Should().BeNull(); + operationContext.RootContext.Should().Be(operationContext); } [Fact] @@ -273,12 +273,23 @@ public void WithTimeout_should_calculate_proper_timeout(TimeSpan expected, TimeS ]; [Fact] - public void WithTimeout_should_set_ParentContext() + public void WithTimeout_should_set_RootContext() { var operationContext = new OperationContext(new Stopwatch(), Timeout.InfiniteTimeSpan, CancellationToken.None); var resultContext = operationContext.WithTimeout(TimeSpan.FromSeconds(10)); - resultContext.ParentContext.Should().Be(operationContext); + resultContext.RootContext.Should().Be(operationContext); + } + + [Fact] + public void WithTimeout_should_preserve_RootContext() + { + var rootContext = new OperationContext(new Stopwatch(), Timeout.InfiniteTimeSpan, CancellationToken.None); + + var intermediateContext = rootContext.WithTimeout(TimeSpan.FromSeconds(200)); + var resultContext = intermediateContext.WithTimeout(TimeSpan.FromSeconds(10)); + + resultContext.RootContext.Should().Be(rootContext); } [Fact]