7
7
import io .cucumber .java .en .Then ;
8
8
import org .apache .commons .lang3 .reflect .TypeUtils ;
9
9
import org .jetbrains .annotations .NotNull ;
10
+ import org .jetbrains .annotations .Nullable ;
10
11
import org .springframework .beans .factory .annotation .Autowired ;
11
12
import org .springframework .data .repository .CrudRepository ;
12
13
import org .springframework .jdbc .core .JdbcTemplate ;
13
14
import org .springframework .orm .jpa .LocalContainerEntityManagerFactoryBean ;
14
15
15
16
import javax .persistence .Entity ;
16
17
import javax .persistence .Table ;
18
+ import javax .persistence .spi .PersistenceUnitInfo ;
17
19
import javax .sql .DataSource ;
18
20
import java .lang .reflect .Type ;
21
+ import java .lang .reflect .TypeVariable ;
22
+ import java .util .ArrayList ;
23
+ import java .util .Collection ;
19
24
import java .util .List ;
20
25
import java .util .Map ;
26
+ import java .util .function .Function ;
21
27
import java .util .stream .Collectors ;
22
28
import java .util .stream .Stream ;
23
29
import java .util .stream .StreamSupport ;
@@ -41,6 +47,7 @@ public class SpringJPASteps {
41
47
42
48
@ Autowired (required = false )
43
49
private List <LocalContainerEntityManagerFactoryBean > entityManagerFactories ;
50
+ private Map <String , Class <?>> entityClassByTableName ;
44
51
45
52
private boolean disableTriggers = true ;
46
53
private final ObjectSteps objects ;
@@ -59,6 +66,17 @@ public void before() {
59
66
DatabaseCleaner .setTriggers (dataSource , schemaToClean , DatabaseCleaner .TriggerStatus .enable );
60
67
});
61
68
}
69
+
70
+ if (entityClassByTableName == null ) {
71
+ entityClassByTableName = entityManagerFactories .stream ()
72
+ .map (LocalContainerEntityManagerFactoryBean ::getPersistenceUnitInfo )
73
+ .map (PersistenceUnitInfo ::getManagedClassNames )
74
+ .flatMap (Collection ::stream )
75
+ .map (TypeParser ::parse )
76
+ .map (type -> (Class <?>) type )
77
+ .filter (entityClass -> entityClass .isAnnotationPresent (Table .class ))
78
+ .collect (Collectors .toMap (entityClass -> entityClass .getAnnotation (Table .class ).name (), Function .identity ()));
79
+ }
62
80
}
63
81
64
82
@ NotNull
@@ -76,12 +94,12 @@ public void the_repository_will_contain(Guard guard, Type repositoryType, Insert
76
94
77
95
@ Given (THAT + GUARD + "the ([^ ]+) table will contain" + INSERTION_MODE + ":$" )
78
96
public void the_table_will_contain (Guard guard , String table , InsertionMode insertionMode , Object content ) {
79
- the_repository_will_contain (guard , getRepositoryForTable (table ), insertionMode , objects .resolve (content ));
97
+ the_repository_will_contain_with_type (guard , getRepositoryForTable (table ), insertionMode , entityTypeByTableNameOrClassName ( table ) , objects .resolve (content ));
80
98
}
81
99
82
100
@ Given (THAT + GUARD + "the " + TYPE + " entities will contain" + INSERTION_MODE + ":$" )
83
101
public void the_entities_will_contain (Guard guard , Type type , InsertionMode insertionMode , Object content ) {
84
- the_repository_will_contain (guard , getRepositoryForEntity (type ), insertionMode , objects .resolve (content ));
102
+ the_repository_will_contain_with_type (guard , getRepositoryForEntity (type ), insertionMode , type , objects .resolve (content ));
85
103
}
86
104
87
105
@ Given (THAT + GUARD + "the triggers are (enable|disable)d$" )
@@ -134,21 +152,27 @@ private void the_repository_contains_nothing(Guard guard, CrudRepository<Object,
134
152
}
135
153
136
154
public <E > void the_repository_will_contain (Guard guard , CrudRepository <E , ?> repository , InsertionMode insertionMode , String entities ) {
155
+ the_repository_will_contain_with_type (guard , repository , insertionMode , getEntityType (repository ), entities );
156
+ }
157
+
158
+ public <E > void the_repository_will_contain_with_type (Guard guard , CrudRepository <E , ?> repository , InsertionMode insertionMode , Type entityType , String entities ) {
137
159
guard .in (objects , () -> {
160
+ if (!(entityType instanceof Class <?>)) return ;
161
+
162
+ Class <E > entityClass = (Class <E >) entityType ;
138
163
if (disableTriggers ) {
139
164
dataSources ().forEach (dataSource -> DatabaseCleaner .setTriggers (dataSource , schemaToClean , DatabaseCleaner .TriggerStatus .disable ));
140
165
}
141
- Class <E > entityType = getEntityType (repository );
142
166
if (insertionMode == InsertionMode .ONLY ) {
143
- String table = entityType .getAnnotation (Table .class ).name ();
167
+ String table = entityClass .getAnnotation (Table .class ).name ();
144
168
DataSource dataSource = entityManagerFactories .stream ()
145
169
.filter (entityManagerFactory -> entityManagerFactory .getPersistenceUnitInfo () != null )
146
- .filter (entityManagerFactory -> entityManagerFactory .getPersistenceUnitInfo ().getManagedClassNames ().contains (entityType .getName ()))
170
+ .filter (entityManagerFactory -> entityManagerFactory .getPersistenceUnitInfo ().getManagedClassNames ().contains (entityClass .getName ()))
147
171
.map (LocalContainerEntityManagerFactoryBean ::getDataSource ).findFirst ()
148
172
.orElseThrow ();
149
173
new JdbcTemplate (dataSource ).update ("TRUNCATE %s RESTART IDENTITY CASCADE" .formatted (table ));
150
174
}
151
- repository .saveAll (Mapper .readAsAListOf (entities , entityType ));
175
+ repository .saveAll (Mapper .readAsAListOf (entities , entityClass ));
152
176
if (disableTriggers ) {
153
177
dataSources ().forEach (dataSource -> DatabaseCleaner .setTriggers (dataSource , schemaToClean , DatabaseCleaner .TriggerStatus .enable ));
154
178
}
@@ -168,20 +192,13 @@ public <E> void add_repository_content_to_variable(Guard guard, String name, Cru
168
192
guard .in (objects , () -> objects .add (name , StreamSupport .stream (repository .findAll ().spliterator (), false ).toList ()));
169
193
}
170
194
171
- @ SuppressWarnings ("unchecked" )
172
195
public <E > CrudRepository <E , ?> getRepositoryForTable (String table ) {
173
- return spring .applicationContext ().getBeansOfType (CrudRepository .class ).values ()
174
- .stream ()
175
- .map (bean -> (CrudRepository <E , ?>) bean )
176
- .filter (r -> {
177
- Class <E > e = getEntityType (r );
178
- return e != null && (
179
- (e .isAnnotationPresent (Table .class ) && e .getAnnotation (Table .class ).name ().equals (table ))
180
- || e .getSimpleName ().equals (table )
181
- || toSnakeCase (e .getSimpleName ()).equals (table ));
182
- }).findFirst ().orElseThrow (() -> new AssertionError (
183
- "there was no CrudRepository found for table '%s'! If you don't need one in your app, you must create one in your tests!" .formatted (table )
184
- ));
196
+ return getRepositoryForEntity (entityTypeByTableNameOrClassName (table ));
197
+ }
198
+
199
+ @ Nullable
200
+ private Type entityTypeByTableNameOrClassName (String entityTableOrClass ) {
201
+ return entityClassByTableName .containsKey (entityTableOrClass ) ? entityClassByTableName .get (entityTableOrClass ) : TypeParser .parse (entityTableOrClass );
185
202
}
186
203
187
204
@ SuppressWarnings ({"unchecked" })
@@ -190,7 +207,26 @@ public <E> void add_repository_content_to_variable(Guard guard, String name, Cru
190
207
return spring .applicationContext ().getBeansOfType (CrudRepository .class ).values ()
191
208
.stream ()
192
209
.map (bean -> (CrudRepository <E , ?>) bean )
193
- .filter (r -> type .equals (TypeUtils .unrollVariables (TypeUtils .getTypeArguments (r .getClass (), CrudRepository .class ), CrudRepository .class .getTypeParameters ()[0 ])))
210
+ .filter (r -> {
211
+ Map <TypeVariable <?>, Type > typeArguments = TypeUtils .getTypeArguments (r .getClass (), CrudRepository .class );
212
+ if (type .equals (TypeUtils .unrollVariables (typeArguments , CrudRepository .class .getTypeParameters ()[0 ])))
213
+ return true ;
214
+
215
+ if (type instanceof Class <?> clazz
216
+ && typeArguments .get (CrudRepository .class .getTypeParameters ()[0 ]) instanceof TypeVariable <?> typeVariable ) {
217
+ Type handledEntitySuperclass = typeVariable .getBounds ()[0 ];
218
+
219
+ List <Class <?>> superClasses = new ArrayList <>();
220
+ while (clazz != Object .class ) {
221
+ superClasses .add (clazz );
222
+ clazz = clazz .getSuperclass ();
223
+ }
224
+
225
+ return superClasses .contains (handledEntitySuperclass );
226
+ }
227
+
228
+ return false ;
229
+ }).sorted ((r1 , r2 ) -> TypeUtils .getTypeArguments (r1 .getClass (), CrudRepository .class ).values ().stream ().findFirst ().orElse (null ) instanceof TypeVariable <?> typeVariable ? 1 : 0 )
194
230
.findFirst ().orElseThrow (() -> new AssertionError ("there was no CrudRepository found for entity %s! If you don't need one in your app, you must create one in your tests!" .formatted (type .getTypeName ())));
195
231
}
196
232
throw new AssertionError (type + " is not an Entity!" );
0 commit comments