diff --git a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/CallbackSupport.java b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/CallbackSupport.java index c7c6cf0e34bf..937b405f66fa 100644 --- a/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/CallbackSupport.java +++ b/junit-jupiter-engine/src/main/java/org/junit/jupiter/engine/descriptor/CallbackSupport.java @@ -8,6 +8,8 @@ * https://www.eclipse.org/legal/epl-v20.html */ +import java.util.Objects; +import java.util.List; package org.junit.jupiter.engine.descriptor; import static org.junit.platform.commons.util.CollectionUtils.forEachInReverseOrder; @@ -23,30 +25,50 @@ */ class CallbackSupport { - static void invokeBeforeCallbacks(Class type, JupiterEngineExecutionContext context, - CallbackInvoker callbackInvoker) { + private static void invokeCallbacks(List extensions, + ExtensionContext extensionContext, ThrowableCollector collector, CallbackInvoker invoker, boolean reverse, boolean breakOnPossibleException){ - ExtensionRegistry registry = context.getExtensionRegistry(); - ExtensionContext extensionContext = context.getExtensionContext(); - ThrowableCollector throwableCollector = context.getThrowableCollector(); - - for (T callback : registry.getExtensions(type)) { - throwableCollector.execute(() -> callbackInvoker.invoke(callback, extensionContext)); - if (throwableCollector.isNotEmpty()) { - break; + if(reverse){ + forEachInReverseOrder(extensions, ext -> collector.execute(() -> invoker.invoke(ext, extensionContext))); + }else{ + for(T ext: extensions){ + collector.execute(()-> invoker.invoke(ext, extensionContext)); + if (breakOnPossibleException && collector.isNotEmpty()) break; } } } - static void invokeAfterCallbacks(Class type, JupiterEngineExecutionContext context, + static void invokeBeforeCallbacks(Class type, JupiterEngineExecutionContext context, CallbackInvoker callbackInvoker) { + + Objects.requireNonNull(type, "type must not be null"); + Objects.requireNonNull(context, "context must not be null"); + Objects.requireNonNull(callbackInvoker, "callbackInvoker must not be null"); + + invokeCallbacks( + context.getExtensionRegistry().getExtensions(type), + context.getExtensionContext(), + context.getThrowableCollector(), + false, // forward order on callbacks + true //break out on any first exception encountered + ) + } - ExtensionRegistry registry = context.getExtensionRegistry(); - ExtensionContext extensionContext = context.getExtensionContext(); - ThrowableCollector throwableCollector = context.getThrowableCollector(); + static void invokeAfterCallbacks(Class type, JupiterEngineExecutionContext context, + CallbackInvoker callbackInvoker) { + + Objects.requireNonNull(type, "type must not be null"); + Objects.requireNonNull(context, "context must not be null"); + Objects.requireNonNull(callbackInvoker, "callbackInvoker must not be null"); - forEachInReverseOrder(registry.getExtensions(type), // - callback -> throwableCollector.execute(() -> callbackInvoker.invoke(callback, extensionContext))); + invokeCallbacks( + context.getExtensionRegistry().getExtensions(type), + context.getExtensionContext(), + context.getThrowableCollector(), + true, // reverse order on callbacks + false // allow all the callbacks to run. + ) + } @FunctionalInterface