Re-implement KeyedMutex using the shared LockQueue & MapRef
BalmungSan committed Nov 23, 2024
1 parent 4036ced commit f351aec
Showing 1 changed file with 15 additions and 108 deletions.
123 changes: 15 additions & 108 deletions std/shared/src/main/scala/cats/effect/std/KeyedMutex.scala
@@ -48,134 +48,41 @@ abstract class KeyedMutex[F[_], K] {
 }
 
 object KeyedMutex {
+  private implicit val cellEq: Eq[LockQueue.Cell] = Eq.fromUniversalEquals
 
   /**
    * Creates a new `KeyedMutex`.
    */
   def apply[F[_], K](implicit F: Concurrent[F]): F[KeyedMutex[F, K]] =
-    Ref
-      .of[F, Map[K, ConcurrentImpl.LockQueueCell]](
-        // Initialize the state with an empty Map.
-        Map.empty
+    MapRef[F, K, LockQueue.Cell].map { mapref =>
+      new ConcurrentImpl[F, K](
+        // Initialize the state with an already completed cell.
+        state = MapRef.defaultedMapRef(mapref, default = LockQueue.EmptyCell)
       )
-      .map(state => new ConcurrentImpl[F, K](state))
+    }
 
   /**
    * Creates a new `KeyedMutex`. Like `apply` but initializes state using another effect
    * constructor.
    */
   def in[F[_], G[_], K](implicit F: Sync[F], G: Async[G]): F[KeyedMutex[G, K]] =
-    Ref
-      .in[F, G, Map[K, ConcurrentImpl.LockQueueCell]](
-        // Initialize the state with an empty Map.
-        Map.empty
+    MapRef.inConcurrentHashMap[F, G, K, LockQueue.Cell]().map { mapref =>
+      new ConcurrentImpl[G, K](
+        // Initialize the state with an already completed cell.
+        state = MapRef.defaultedMapRef(mapref, default = LockQueue.EmptyCell)
      )
-      .map(state => new ConcurrentImpl[G, K](state))
+    }
 
   private final class ConcurrentImpl[F[_], K](
-      state: Ref[F, Map[K, ConcurrentImpl.LockQueueCell]]
+      state: MapRef[F, K, LockQueue.Cell]
   )(
       implicit F: Concurrent[F]
   ) extends KeyedMutex[F, K] {
 
-    // This is a variant of the Craig, Landin, and Hagersten
-    // (CLH) queue lock for each key.
-    // Queue nodes (called cells below) are `Deferred`s,
-    // so fibers can suspend and wake up
-    // (instead of spinning, like in the original algorithm).
-
-    // Awakes whoever is waiting for us with the next cell in the queue.
-    private def awakeCell(
-        key: K,
-        ourCell: ConcurrentImpl.WaitingCell[F],
-        nextCell: ConcurrentImpl.LockQueueCell
-    ): F[Unit] =
-      state.access.flatMap {
-        // If the current last cell in the queue for the given key is our cell,
-        // then that means nobody is waiting for us.
-        // Thus, we can just set the state to the next cell in the queue.
-        // Also, if the next cell is the empty one, we can just remove the key from the map.
-        // Otherwise, we awake whoever is waiting for us.
-        case (map, setter) =>
-          val lastCell = map(key) // Safe.
-          if (lastCell eq ourCell)
-            if (nextCell eq ConcurrentImpl.EmptyCell)
-              setter(map - key)
-            else
-              setter(map.updated(key, value = nextCell))
-          else F.pure(false)
-      } flatMap {
-        case false => ourCell.complete(nextCell).void
-        case true => F.unit
-      }
-
-    // Cancels a Fiber waiting for the Mutex.
-    private def cancel(
-        key: K,
-        ourCell: ConcurrentImpl.WaitingCell[F],
-        nextCell: ConcurrentImpl.LockQueueCell
-    ): F[Unit] =
-      awakeCell(key, ourCell, nextCell)
-
-    // Acquires the Mutex.
-    private def acquire(key: K)(poll: Poll[F]): F[ConcurrentImpl.WaitingCell[F]] =
-      ConcurrentImpl.LockQueueCell[F].flatMap { ourCell =>
-        // Atomically get the last cell in the queue for the given key,
-        // and put ourselves as the last one.
-        state
-          .modify { map =>
-            val newMap = map.updated(key, value = ourCell)
-            val lastCell = map.getOrElse(key, default = ConcurrentImpl.EmptyCell)
-
-            newMap -> lastCell
-          }
-          .flatMap { lastCell =>
-            // Then we check what the next cell is.
-            // There are two options:
-            //  + EmptyCell: Signaling that the mutex is free.
-            //  + WaitingCell: Which means there is someone ahead of us in the queue.
-            //    Thus, we wait for that cell to complete; and then check again.
-            //
-            // Only the waiting process is cancelable.
-            // If we are cancelled while waiting,
-            // we notify our waiter with the cell ahead of us.
-            def loop(
-                nextCell: ConcurrentImpl.LockQueueCell
-            ): F[ConcurrentImpl.WaitingCell[F]] =
-              if (nextCell eq ConcurrentImpl.EmptyCell) F.pure(ourCell)
-              else {
-                F.onCancel(
-                  poll(nextCell.asInstanceOf[ConcurrentImpl.WaitingCell[F]].get),
-                  cancel(key, ourCell, nextCell)
-                ).flatMap(loop)
-              }
-
-            loop(nextCell = lastCell)
-          }
-      }
-
-    // Releases the Mutex.
-    private def release(key: K)(ourCell: ConcurrentImpl.WaitingCell[F]): F[Unit] =
-      awakeCell(key, ourCell, nextCell = ConcurrentImpl.EmptyCell)
-
     override def lock(key: K): Resource[F, Unit] =
-      Resource.makeFull[F, ConcurrentImpl.WaitingCell[F]](acquire(key))(release(key)).void
+      LockQueue.lock(queue = state(key))
 
     override def mapK[G[_]](f: F ~> G)(implicit G: MonadCancel[G, _]): KeyedMutex[G, K] =
-      new KeyedMutex.TransformedKeyedMutex(this, f)
+      new TransformedKeyedMutex(this, f)
   }
 
-  private object ConcurrentImpl {
-    // Represents a queue of waiters for the mutex.
-    private[KeyedMutex] final type LockQueueCell = AnyRef
-    // Represents the first cell of the queue.
-    private[KeyedMutex] final type EmptyCell = LockQueueCell
-    private[KeyedMutex] final val EmptyCell: EmptyCell = null
-    // Represents a waiting cell in the queue.
-    private[KeyedMutex] final type WaitingCell[F[_]] = Deferred[F, LockQueueCell]
-
-    private[KeyedMutex] def LockQueueCell[F[_]](implicit F: Concurrent[F]): F[WaitingCell[F]] =
-      Deferred[F, LockQueueCell]
-  }
-
   private final class TransformedKeyedMutex[F[_], G[_], K](
@@ -187,6 +94,6 @@ object KeyedMutex {
       underlying.lock(key).mapK(f)
 
     override def mapK[H[_]](f: G ~> H)(implicit H: MonadCancel[H, _]): KeyedMutex[H, K] =
-      new KeyedMutex.TransformedKeyedMutex(this, f)
+      new TransformedKeyedMutex(this, f)
   }
 }
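Note: the core of the change is replacing the Ref[F, Map[K, ...]] state with a per-key MapRef whose absent keys read as LockQueue.EmptyCell, so KeyedMutex no longer has to seed an empty Map or remove keys on release. The sketch below illustrates that defaulted-MapRef pattern in isolation; the Cell/EmptyCell types are hypothetical stand-ins for the internal LockQueue types, and the MapRef.apply and MapRef.defaultedMapRef constructors are assumed to behave as they are used in the diff above.

import cats.Eq
import cats.effect.IO
import cats.effect.std.MapRef

object DefaultedMapRefSketch {
  // Illustrative stand-in for the internal LockQueue.Cell type.
  sealed trait Cell
  case object EmptyCell extends Cell

  // Mirrors the `cellEq` instance added in the commit: the defaulted
  // wrapper needs an Eq[V] to translate the default value to and from
  // an absent key in the underlying map.
  implicit val cellEq: Eq[Cell] = Eq.fromUniversalEquals

  // Reads of absent keys yield EmptyCell instead of None, which is the
  // property the new KeyedMutex state relies on.
  val makeState: IO[MapRef[IO, String, Cell]] =
    MapRef[IO, String, Cell].map { underlying =>
      MapRef.defaultedMapRef(underlying, default = EmptyCell)
    }
}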

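For reference, a minimal usage sketch of the KeyedMutex API defined in this file (only the `apply` constructor and `lock` shown in the diff are used; the key values and printed messages are made up for illustration):

import cats.effect.{IO, IOApp}
import cats.effect.std.KeyedMutex
import cats.syntax.all._

object KeyedMutexUsage extends IOApp.Simple {
  def run: IO[Unit] =
    KeyedMutex[IO, String].flatMap { mutex =>
      // Locks for the same key are mutually exclusive,
      // while locks for different keys proceed independently.
      val forA = mutex.lock("key-a").surround(IO.println("critical section for key-a"))
      val forB = mutex.lock("key-b").surround(IO.println("critical section for key-b"))
      (forA, forB).parTupled.void
    }
}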