Skip to content

fix: improve combined worker handling #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@ import org.junit.jupiter.api.extension.AfterAllCallback
import org.junit.jupiter.api.extension.BeforeAllCallback
import org.junit.jupiter.api.extension.ExtensionContext
import org.junit.jupiter.api.extension.ExtensionContext.Namespace
import org.junit.jupiter.api.extension.ExtensionContext.Namespace.GLOBAL
import org.junit.jupiter.api.parallel.ResourceLock

/**
* Base class for extensions that wish to provide a Corda network
* throughout test suite execution
*/
abstract class AbstractCorda5Extension : BeforeAllCallback, AfterAllCallback, JupiterExtensionConfigSupport {

companion object {
private const val allCallbackCounterKey = "AbstractCorda5Extension#allCallbackCounterKey"
}

lateinit var config: Corda5NodesConfig
protected var started = false

abstract fun getNamespace(): Namespace

Expand All @@ -21,17 +26,33 @@ abstract class AbstractCorda5Extension : BeforeAllCallback, AfterAllCallback, Ju
): Corda5NodesConfig

/** Start the Corda network */
@ResourceLock(allCallbackCounterKey)
override fun beforeAll(extensionContext: ExtensionContext) {
config = getConfig(extensionContext)
started = true
if (config.combinedWorkerEnabled) {
val incrementedCallbackCount = extensionContext.root.getStore(GLOBAL).getOrDefault(allCallbackCounterKey, Int::class.java, 0)
.plus(1)
extensionContext.root.getStore(GLOBAL).put(allCallbackCounterKey, incrementedCallbackCount)
if (incrementedCallbackCount == 1) clearNodeHandles()
initNodeHandles()
}
}

/** Stop the Corda network */
@ResourceLock(allCallbackCounterKey)
override fun afterAll(extensionContext: ExtensionContext) {
if (config.combinedWorkerMode == CombinedWorkerMode.PER_CLASS)
clearNodeHandles()
// NO-OP
if (config.combinedWorkerEnabled) {
val decrementedCallbackCount = extensionContext.root.getStore(GLOBAL).get(allCallbackCounterKey, Int::class.java)
.minus(1)
extensionContext.root.getStore(GLOBAL).put(allCallbackCounterKey, decrementedCallbackCount)
if (decrementedCallbackCount == 0 || config.combinedWorkerMode == CombinedWorkerMode.PER_CLASS)
clearNodeHandles()
}
}

abstract fun initNodeHandles()

abstract fun clearNodeHandles()

abstract fun stopNodesNetwork()
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ data class Corda5NodesConfig(
).exists()
) {
currentDir = currentDir.parentFile
logger.info("currentDir: ${currentDir.absolutePath}")
}
logger.fine("Using Gradle module dir: ${currentDir.absolutePath}")
currentDir
}
}

val combinedWorkerEnabled = combinedWorkerMode != CombinedWorkerMode.NONE
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@ class Corda5NodesExtension : AbstractCorda5Extension(), ParameterResolver {

private var nodeHandlesHelper: NodeHandlesHelper? = null

override fun beforeAll(extensionContext: ExtensionContext) {
super.beforeAll(extensionContext)
nodeHandlesHelper = NodeHandlesHelper(config)
}

override fun afterAll(extensionContext: ExtensionContext) {
super.afterAll(extensionContext)
}

override fun getConfig(
extensionContext: ExtensionContext
) = findConfig(getRequiredTestClass(extensionContext))

override fun initNodeHandles() {
nodeHandlesHelper = NodeHandlesHelper(config)
}

override fun clearNodeHandles() {
nodeHandlesHelper?.reset()
nodeHandlesHelper = null
}

override fun stopNodesNetwork() {
nodeHandlesHelper?.stop()
}

fun findConfig(testClass: Class<*>): Corda5NodesConfig =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ data class NodeHandle(
flowStatusResponse = flowsClient.flowStatus(holdingIdentityShortHash, clientRequestId)
when {
flowStatusResponse.isFinal() -> break
else -> logger.info("Non-final flow status will retry $flowStatusResponse")
else -> logger.fine("Non-final flow status, will retry: $flowStatusResponse")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,19 @@ class NodeHandlesHelper(
}
}

val nodeHandles: NodeHandles
get() {
when (config.combinedWorkerMode) {
CombinedWorkerMode.PER_CLASS ->
reset().also { nodeHandlesCache = buildNodeHandles() }
val nodeHandles: NodeHandles by lazy {
when (config.combinedWorkerMode) {
CombinedWorkerMode.PER_CLASS ->
reset().also { nodeHandlesCache = buildNodeHandles() }

CombinedWorkerMode.SHARED ->
if (nodeHandlesCache == null) nodeHandlesCache = buildNodeHandles()
CombinedWorkerMode.SHARED ->
if (nodeHandlesCache == null) nodeHandlesCache = buildNodeHandles()

CombinedWorkerMode.NONE ->
nodeHandlesCache = nodeHandles(nodesClient.nodes().virtualNodes)
}
return nodeHandlesCache!!
CombinedWorkerMode.NONE ->
nodeHandlesCache = nodeHandles(nodesClient.nodes().virtualNodes)
}
nodeHandlesCache!!
}

private val gradle by lazy {
GradleHelper(
Expand All @@ -56,6 +55,10 @@ class NodeHandlesHelper(

fun reset() {
nodeHandlesCache = null
stop()
}

fun stop() {
gradle.executeTaskAndWait("stopCorda")
}

Expand All @@ -71,7 +74,7 @@ class NodeHandlesHelper(
}
}

logger.info("Combined worker started, node list, size: ${nodesResponse!!.size}")
logger.fine("Combined worker started, node list, size: ${nodesResponse!!.size}")
if (nodesResponse.isEmpty()) {
gradle.executeTaskAndWait("5-vNodeSetup")
nodesResponse = virtualNodeInfos(::nodesEmptyResponseCheck)
Expand All @@ -90,11 +93,10 @@ class NodeHandlesHelper(
} catch (e: Exception) {
onError()
var maxWait = 2 * 60
logger.fine("Waiting for Combined Worker nodes...")
while (reloadCheck(nodesResponse) && maxWait > 0) {
maxWait -= 1
TimeUnit.SECONDS.sleep(1L)

logger.info("Waiting for Combined Worker nodes: $maxWait")
try {
nodesResponse = nodesClient.nodes().virtualNodes
} catch (e: Exception) {
Expand Down