Skip to content

Commit e8bafa2

Browse files
authored
Add thread-safety checks option (#540)
1 parent 7f4ab36 commit e8bafa2

File tree

5 files changed

+146
-6
lines changed

5 files changed

+146
-6
lines changed

src/YesSql.Abstractions/IConfiguration.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,19 @@ public interface IConfiguration
6666
/// </summary>
6767
bool QueryGatingEnabled { get; set; }
6868

69+
/// <summary>
70+
/// Gets or sets whether the thread-safety checks are enabled.
71+
/// </summary>
72+
/// <remarks>
73+
/// When enabled, YesSql will throw an <see cref="InvalidOperationException" /> if two threads are trying to execute read or write
74+
/// operations on the database concurrently. This can help investigating thread-safety issue where an <see cref="ISession"/>
75+
/// instance is shared which is not supported.
76+
/// </remarks>
77+
/// <value>
78+
/// The default value is <see langword="false"/>.
79+
/// </value>
80+
public bool EnableThreadSafetyChecks { get; set; }
81+
6982
/// <summary>
7083
/// Gets the collection of types that must be checked for concurrency.
7184
/// </summary>

src/YesSql.Core/Configuration.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public Configuration()
2121
TablePrefix = "";
2222
CommandsPageSize = 500;
2323
QueryGatingEnabled = true;
24+
EnableThreadSafetyChecks = false;
2425
Logger = NullLogger.Instance;
2526
ConcurrentTypes = new HashSet<Type>();
2627
TableNameConvention = new DefaultTableNameConvention();
@@ -35,6 +36,7 @@ public Configuration()
3536
public string Schema { get; set; }
3637
public int CommandsPageSize { get; set; }
3738
public bool QueryGatingEnabled { get; set; }
39+
public bool EnableThreadSafetyChecks { get; set; }
3840
public IIdGenerator IdGenerator { get; set; }
3941
public ILogger Logger { get; set; }
4042
public HashSet<Type> ConcurrentTypes { get; }

src/YesSql.Core/Services/DefaultQuery.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,8 @@ public async Task<int> CountAsync()
11061106
var parameters = localBuilder.Parameters;
11071107
var key = new WorkerQueryKey(sql, localBuilder.Parameters);
11081108

1109+
_session.EnterAsyncExecution();
1110+
11091111
try
11101112
{
11111113
return await _session._store.ProduceAsync(key, static (state) =>
@@ -1127,6 +1129,10 @@ public async Task<int> CountAsync()
11271129

11281130
throw;
11291131
}
1132+
finally
1133+
{
1134+
_session.ExitAsyncExecution();
1135+
}
11301136
}
11311137

11321138
IQuery<T> IQuery.For<T>(bool filterType)
@@ -1206,6 +1212,8 @@ protected async Task<T> FirstOrDefaultImpl()
12061212

12071213
_query.Page(1, 0);
12081214

1215+
_query._session.EnterAsyncExecution();
1216+
12091217
try
12101218
{
12111219
if (typeof(IIndex).IsAssignableFrom(typeof(T)))
@@ -1259,6 +1267,10 @@ protected async Task<T> FirstOrDefaultImpl()
12591267
await _query._session.CancelAsync();
12601268
throw;
12611269
}
1270+
finally
1271+
{
1272+
_query._session.ExitAsyncExecution();
1273+
}
12621274
}
12631275

12641276
Task<IEnumerable<T>> IQuery<T>.ListAsync()
@@ -1297,6 +1309,8 @@ internal async Task<IEnumerable<T>> ListImpl()
12971309
}
12981310
}
12991311

1312+
_query._session.EnterAsyncExecution();
1313+
13001314
try
13011315
{
13021316
if (typeof(IIndex).IsAssignableFrom(typeof(T)))
@@ -1311,6 +1325,7 @@ internal async Task<IEnumerable<T>> ListImpl()
13111325

13121326
var sql = sqlBuilder.ToSqlString();
13131327
var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters);
1328+
13141329
return await _query._session._store.ProduceAsync(key, static (state) =>
13151330
{
13161331
var logger = state.Query._session._store.Configuration.Logger;
@@ -1365,6 +1380,10 @@ internal async Task<IEnumerable<T>> ListImpl()
13651380

13661381
throw;
13671382
}
1383+
finally
1384+
{
1385+
_query._session.ExitAsyncExecution();
1386+
}
13681387
}
13691388

13701389
private string GetDeduplicatedQuery()

src/YesSql.Core/Session.cs

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Data.Common;
77
using System.Linq;
88
using System.Linq.Expressions;
9+
using System.Threading;
910
using System.Threading.Tasks;
1011
using YesSql.Commands;
1112
using YesSql.Data;
@@ -35,6 +36,10 @@ public class Session : ISession
3536
private readonly ILogger _logger;
3637
private readonly bool _withTracking;
3738

39+
private readonly bool _enableThreadSafetyChecks;
40+
private int _asyncOperations = 0;
41+
private string _previousStackTrace = null;
42+
3843
public Session(Store store, bool withTracking = true)
3944
{
4045
_store = store;
@@ -43,6 +48,7 @@ public Session(Store store, bool withTracking = true)
4348
_logger = store.Configuration.Logger;
4449
_withTracking = withTracking;
4550
_defaultState = new SessionState();
51+
_enableThreadSafetyChecks = _store.Configuration.EnableThreadSafetyChecks;
4652
_collectionStates = new Dictionary<string, SessionState>()
4753
{
4854
[string.Empty] = _defaultState
@@ -500,7 +506,7 @@ public async Task<IEnumerable<T>> GetAsync<T>(long[] ids, string collection = nu
500506
var key = new WorkerQueryKey(nameof(GetAsync), ids);
501507
try
502508
{
503-
var documents = await _store.ProduceAsync(key, (state) =>
509+
var documents = await _store.ProduceAsync(key, static (state) =>
504510
{
505511
var logger = state.Store.Configuration.Logger;
506512

@@ -661,7 +667,12 @@ public void Dispose()
661667
GC.SuppressFinalize(this);
662668
}
663669

664-
public async Task FlushAsync()
670+
public Task FlushAsync()
671+
{
672+
return FlushInternalAsync(false);
673+
}
674+
675+
private async Task FlushInternalAsync(bool saving)
665676
{
666677
if (!HasWork())
667678
{
@@ -684,6 +695,12 @@ public async Task FlushAsync()
684695

685696
CheckDisposed();
686697

698+
// Only check thread-safety if not called from SaveChangesAsync
699+
if (!saving)
700+
{
701+
EnterAsyncExecution();
702+
}
703+
687704
try
688705
{
689706
// saving all tracked entities
@@ -767,6 +784,12 @@ public async Task FlushAsync()
767784

768785
_commands?.Clear();
769786
_flushing = false;
787+
788+
// Only check thread-safety if not called from SaveChangesAsync
789+
if (!saving)
790+
{
791+
ExitAsyncExecution();
792+
}
770793
}
771794
}
772795

@@ -856,20 +879,48 @@ private void BatchCommands()
856879
_commands.AddRange(batches);
857880
}
858881

882+
public void EnterAsyncExecution()
883+
{
884+
if (!_enableThreadSafetyChecks)
885+
{
886+
return;
887+
}
888+
889+
if (Interlocked.Increment(ref _asyncOperations) > 1)
890+
{
891+
throw new InvalidOperationException($"Two concurrent threads have been detected accessing the same ISession instance from: \n{Environment.StackTrace}\nand:\n{_previousStackTrace}\n---");
892+
}
893+
894+
_previousStackTrace = Environment.StackTrace;
895+
}
896+
897+
public void ExitAsyncExecution()
898+
{
899+
if (!_enableThreadSafetyChecks)
900+
{
901+
return;
902+
}
903+
904+
Interlocked.Decrement(ref _asyncOperations);
905+
}
906+
859907
public async Task SaveChangesAsync()
860908
{
909+
EnterAsyncExecution();
910+
861911
try
862912
{
863913
if (!_cancel)
864914
{
865-
await FlushAsync();
915+
await FlushInternalAsync(true);
866916

867917
_save = true;
868918
}
869919
}
870920
finally
871921
{
872922
await CommitOrRollbackTransactionAsync();
923+
ExitAsyncExecution();
873924
}
874925
}
875926

@@ -1362,11 +1413,20 @@ public async Task<DbTransaction> BeginTransactionAsync(IsolationLevel isolationL
13621413

13631414
public Task CancelAsync()
13641415
{
1365-
CheckDisposed();
1416+
EnterAsyncExecution();
1417+
1418+
try
1419+
{
1420+
CheckDisposed();
13661421

1367-
_cancel = true;
1422+
_cancel = true;
13681423

1369-
return ReleaseTransactionAsync();
1424+
return ReleaseTransactionAsync();
1425+
}
1426+
finally
1427+
{
1428+
ExitAsyncExecution();
1429+
}
13701430
}
13711431

13721432
public IStore Store => _store;

test/YesSql.Tests/CoreTests.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Data.Common;
66
using System.Diagnostics;
77
using System.Linq;
8+
using System.Security.Cryptography;
89
using System.Threading;
910
using System.Threading.Tasks;
1011
using Xunit;
@@ -6401,6 +6402,51 @@ await session.SaveAsync(new Article
64016402
Assert.Equal(10, result);
64026403
}
64036404

6405+
[Fact]
6406+
public virtual async Task ShouldDetectThreadSafetyIssues()
6407+
{
6408+
try
6409+
{
6410+
_store.Configuration.EnableThreadSafetyChecks = true;
6411+
6412+
await using var session = _store.CreateSession();
6413+
6414+
_store.Configuration.EnableThreadSafetyChecks = false;
6415+
6416+
var person = new Person { Firstname = "Bill" };
6417+
await session.SaveAsync(person);
6418+
await session.SaveChangesAsync();
6419+
6420+
Task[] tasks = null;
6421+
6422+
var throws = Assert.ThrowsAsync<InvalidOperationException>(async () =>
6423+
{
6424+
tasks = Enumerable.Range(0, 10).Select(x => Task.Run(DoWork)).ToArray();
6425+
await Task.WhenAll(tasks);
6426+
});
6427+
6428+
var result = Task.WaitAny(throws, Task.Delay(5000));
6429+
6430+
Assert.Equal(0, result);
6431+
6432+
async Task DoWork()
6433+
{
6434+
while (true)
6435+
{
6436+
var p = await session.Query<Person>().FirstOrDefaultAsync();
6437+
Assert.NotNull(p);
6438+
6439+
person.Firstname = "Bill" + RandomNumberGenerator.GetInt32(100);
6440+
await session.FlushAsync();
6441+
}
6442+
}
6443+
}
6444+
finally
6445+
{
6446+
_store.Configuration.EnableThreadSafetyChecks = false;
6447+
}
6448+
}
6449+
64046450
#region FilterTests
64056451

64066452
[Fact]

0 commit comments

Comments
 (0)