5
5
import org .junit .platform .commons .support .ModifierSupport ;
6
6
import org .junit .platform .commons .support .ReflectionSupport ;
7
7
import org .junit .runner .Description ;
8
+ import org .junit .runners .model .FrameworkMethod ;
8
9
import org .junit .runners .model .MultipleFailureException ;
9
10
import org .testcontainers .lifecycle .Startable ;
10
11
import org .testcontainers .lifecycle .TestDescription ;
11
12
import org .testcontainers .lifecycle .TestLifecycleAware ;
12
13
14
+ import java .lang .reflect .AnnotatedElement ;
13
15
import java .lang .reflect .Field ;
16
+ import java .lang .reflect .Member ;
17
+ import java .lang .reflect .Method ;
14
18
import java .util .ArrayList ;
15
19
import java .util .Collections ;
16
20
import java .util .List ;
20
24
import java .util .function .Consumer ;
21
25
import java .util .function .Predicate ;
22
26
import java .util .stream .Collectors ;
27
+ import java .util .stream .Stream ;
23
28
24
29
/**
25
30
* Integrates Testcontainers with the JUnit4 lifecycle.
26
31
*/
27
32
public final class Testcontainers extends FailureDetectingExternalResource {
28
33
34
+ private static HierarchyTraversalMode TRAVERSAL_MODE = HierarchyTraversalMode .TOP_DOWN ;
35
+
29
36
private final Object testInstance ;
30
37
31
38
private List <Startable > startedContainers = Collections .emptyList ();
@@ -116,51 +123,82 @@ protected void finished(Description description) throws Exception {
116
123
}
117
124
118
125
private List <Startable > findContainers (Description description ) {
119
- if (description .getTestClass () == null ) {
126
+ Class <?> testClass = description .getTestClass ();
127
+ if (testClass == null ) {
120
128
return Collections .emptyList ();
121
129
}
122
- Predicate <Field > isTargetedContainer = isContainer ();
123
- if (testInstance == null ) {
124
- isTargetedContainer = isTargetedContainer .and (ModifierSupport ::isStatic );
125
- } else {
126
- isTargetedContainer = isTargetedContainer .and (ModifierSupport ::isNotStatic );
127
- }
128
130
129
- return ReflectionSupport
130
- .findFields (description .getTestClass (), isTargetedContainer , HierarchyTraversalMode .TOP_DOWN )
131
- .stream ()
132
- .map (this ::getContainerInstance )
131
+ Predicate <Member > hasExpectedModifier = testInstance == null
132
+ ? ModifierSupport ::isStatic
133
+ : ModifierSupport ::isNotStatic ;
134
+
135
+ return Stream
136
+ .of (
137
+ ReflectionSupport
138
+ .findMethods (testClass , isContainerMethod ().and (hasExpectedModifier ), TRAVERSAL_MODE )
139
+ .stream ()
140
+ .map (this ::getContainerInstance ),
141
+ ReflectionSupport
142
+ .findFields (testClass , isContainerField ().and (hasExpectedModifier ), TRAVERSAL_MODE )
143
+ .stream ()
144
+ .map (this ::getContainerInstance )
145
+ )
146
+ .flatMap (s -> s )
133
147
.collect (Collectors .toList ());
134
148
}
135
149
136
- private static Predicate <Field > isContainer () {
137
- return field -> {
138
- boolean isAnnotatedWithContainer = AnnotationSupport .isAnnotated (field , Container .class );
139
- if (isAnnotatedWithContainer ) {
140
- boolean isStartable = Startable .class .isAssignableFrom (field .getType ());
150
+ private static Predicate <Method > isContainerMethod () {
151
+ return method -> isAnnotatedWithContainer (method );
152
+ }
141
153
142
- if (!isStartable ) {
143
- throw new RuntimeException (
144
- String .format ("The @Container field '%s' does not implement Startable" , field .getName ())
145
- );
146
- }
147
- return true ;
148
- }
149
- return false ;
150
- };
154
+ private static Predicate <Field > isContainerField () {
155
+ return field -> isAnnotatedWithContainer (field );
156
+ }
157
+
158
+ private static boolean isAnnotatedWithContainer (AnnotatedElement element ) {
159
+ return AnnotationSupport .isAnnotated (element , Container .class );
160
+ }
161
+
162
+ private Startable getContainerInstance (Method method ) {
163
+ if (!Startable .class .isAssignableFrom (method .getReturnType ())) {
164
+ throw new RuntimeException (
165
+ String .format ("The @Container method '%s()' does not return a Startable" , method .getName ())
166
+ );
167
+ }
168
+
169
+ Object container = null ;
170
+ try {
171
+ method .setAccessible (true );
172
+ container = new FrameworkMethod (method ).invokeExplosively (testInstance );
173
+ } catch (Throwable e ) {
174
+ throwUnchecked (e );
175
+ }
176
+
177
+ if (container == null ) {
178
+ throw new RuntimeException (String .format ("The @Container method '%s()' returned null" , method .getName ()));
179
+ }
180
+ return (Startable ) container ;
151
181
}
152
182
153
183
private Startable getContainerInstance (Field field ) {
184
+ if (!Startable .class .isAssignableFrom (field .getType ())) {
185
+ throw new RuntimeException (
186
+ String .format ("The @Container field '%s' does not implement Startable" , field .getName ())
187
+ );
188
+ }
189
+
190
+ Startable container = null ;
154
191
try {
155
192
field .setAccessible (true );
156
- Startable containerInstance = (Startable ) field .get (testInstance );
157
- if (containerInstance == null ) {
158
- throw new RuntimeException ("Container " + field .getName () + " needs to be initialized" );
159
- }
160
- return containerInstance ;
193
+ container = (Startable ) field .get (testInstance );
161
194
} catch (IllegalAccessException e ) {
162
- throw new RuntimeException ("Cannot access container defined in field " + field .getName ());
195
+ throwUnchecked (e );
196
+ }
197
+
198
+ if (container == null ) {
199
+ throw new RuntimeException ("Container " + field .getName () + " needs to be initialized" );
163
200
}
201
+ return container ;
164
202
}
165
203
166
204
private static <T > void forEachReversed (List <T > list , Consumer <? super T > callback ) {
@@ -169,4 +207,9 @@ private static <T> void forEachReversed(List<T> list, Consumer<? super T> callba
169
207
callback .accept (iterator .previous ());
170
208
}
171
209
}
210
+
211
+ @ SuppressWarnings ("unchecked" )
212
+ private static <T extends Throwable > void throwUnchecked (Throwable e ) throws T {
213
+ throw (T ) e ;
214
+ }
172
215
}
0 commit comments