Skip to content

Commit 83f28d8

Browse files
authored
Make ZQuery#run reentrant safe (#499)
1 parent a345a27 commit 83f28d8

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

zio-query/shared/src/main/scala/zio/query/ZQuery.scala

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -539,23 +539,35 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
539539
* Returns an effect that models executing this query with the specified
540540
* cache.
541541
*/
542-
def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] =
542+
def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] = {
543+
import ZQuery.{currentCache, currentScope}
544+
545+
def setRef[V](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], newValue: V): V = {
546+
val oldValue = state.getFiberRefOrNull(fiberRef)
547+
state.setFiberRef(fiberRef, newValue)
548+
oldValue
549+
}
550+
551+
def resetRef[V <: AnyRef](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], oldValue: V): Unit =
552+
if (oldValue ne null) state.setFiberRef(fiberRef, oldValue) else state.deleteFiberRef(fiberRef)
553+
543554
asExitOrElse(null) match {
544555
case null =>
545556
ZIO.uninterruptibleMask { restore =>
546557
ZIO.withFiberRuntime[R, E, A] { (state, _) =>
547-
val scope = QueryScope.make()
548-
state.setFiberRef(ZQuery.currentCache, Some(cache))
549-
state.setFiberRef(ZQuery.currentScope, scope)
558+
val scope = QueryScope.make()
559+
val oldCache = setRef(state, currentCache, Some(cache))
560+
val oldScope = setRef(state, currentScope, scope)
550561
restore(runToZIO).exitWith { exit =>
551-
state.deleteFiberRef(ZQuery.currentCache)
552-
state.deleteFiberRef(ZQuery.currentScope)
562+
resetRef(state, currentCache, oldCache)
563+
resetRef(state, currentScope, oldScope)
553564
scope.closeAndExitWith(exit)
554565
}
555566
}
556567
}
557568
case exit => exit
558569
}
570+
}
559571

560572
/**
561573
* Returns an effect that models executing this query, returning the query

zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package zio.query
22

33
import zio._
44
import zio.query.QueryAspect._
5+
import zio.query.internal.QueryScope
56
import zio.test.Assertion._
67
import zio.test.TestAspect.{after, nonFlaky, silent}
78
import zio.test.{TestClock, TestConsole, TestEnvironment, _}
@@ -270,7 +271,7 @@ object ZQuerySpec extends ZIOBaseSpec {
270271
assert(log)(hasAt(0)(containsString("GetNameById(1)"))) &&
271272
assert(log)(hasAt(0)(containsString("GetNameById(2)"))) &&
272273
assert(log)(hasAt(1)(containsString("GetNameById(1)")))
273-
} @@ nonFlaky,
274+
} @@ nonFlaky(10),
274275
suite("race")(
275276
test("race with never") {
276277
val query = ZQuery.never.race(ZQuery.succeed(()))
@@ -370,6 +371,28 @@ object ZQuerySpec extends ZIOBaseSpec {
370371
value <- ref.get
371372
} yield assertTrue(value == 1, results.forall(_.isLeft))
372373
}
374+
),
375+
suite("run")(
376+
test("cache is reentrant safe") {
377+
val q =
378+
for {
379+
c1 <- ZQuery.fromZIO(ZQuery.currentCache.get)
380+
_ <- ZQuery.fromZIO(ZQuery.succeed("foo").run)
381+
c2 <- ZQuery.fromZIO(ZQuery.currentCache.get)
382+
} yield (c1, c2)
383+
384+
q.run.map { case (c1, c2) => assertTrue(c1.isDefined, c1 == c2) }
385+
},
386+
test("scope is reentrant safe") {
387+
val q =
388+
for {
389+
c1 <- ZQuery.fromZIO(ZQuery.currentScope.get)
390+
_ <- ZQuery.fromZIO(ZQuery.succeed("foo").run)
391+
c2 <- ZQuery.fromZIO(ZQuery.currentScope.get)
392+
} yield (c1, c2)
393+
394+
q.run.map { case (c1, c2) => assertTrue(c1 != QueryScope.NoOp, c1 == c2) }
395+
}
373396
)
374397
) @@ silent
375398

0 commit comments

Comments
 (0)