13
13
import org .springframework .jdbc .core .JdbcTemplate ;
14
14
import org .springframework .orm .jpa .LocalContainerEntityManagerFactoryBean ;
15
15
16
- import javax .persistence .Entity ;
17
16
import javax .persistence .Table ;
18
- import javax .persistence .spi .PersistenceUnitInfo ;
19
17
import javax .sql .DataSource ;
20
18
import java .lang .reflect .Type ;
21
19
import java .lang .reflect .TypeVariable ;
22
- import java .util .*;
23
- import java .util .function .Function ;
20
+ import java .util .List ;
21
+ import java .util .Map ;
22
+ import java .util .Optional ;
24
23
import java .util .stream .Collectors ;
25
24
import java .util .stream .Stream ;
26
25
import java .util .stream .StreamSupport ;
@@ -44,7 +43,8 @@ public class SpringJPASteps {
44
43
45
44
@ Autowired (required = false )
46
45
private List <LocalContainerEntityManagerFactoryBean > entityManagerFactories ;
47
- private Map <String , Class <?>> entityClassByTableName ;
46
+ private Map <Type , CrudRepository <?, ?>> crudRepositoryByClass ;
47
+ private Map <String , Type > entityClassByTableName ;
48
48
49
49
private boolean disableTriggers = true ;
50
50
private final ObjectSteps objects ;
@@ -64,15 +64,41 @@ public void before() {
64
64
});
65
65
}
66
66
67
+ if (crudRepositoryByClass == null ) {
68
+ crudRepositoryByClass = spring .applicationContext ().getBeansOfType (CrudRepository .class ).values ()
69
+ .stream ()
70
+ .<Map .Entry <Class <?>, CrudRepository <?, ?>>>mapMulti ((crudRepository , consumer ) -> {
71
+ Map <TypeVariable <?>, Type > typeArguments = TypeUtils .getTypeArguments (crudRepository .getClass (), CrudRepository .class );
72
+ Type type = typeArguments .get (CrudRepository .class .getTypeParameters ()[0 ]);
73
+
74
+ if (type instanceof TypeVariable <?> typeVariable ) {
75
+ type = typeVariable .getBounds ()[0 ];
76
+ TypeParser .getSubtypesOf ((Class <?>) type )
77
+ .forEach (clazz -> consumer .accept (Map .entry (clazz , crudRepository )));
78
+ }
79
+
80
+ if (type instanceof Class <?> clazz ) consumer .accept (Map .entry (clazz , crudRepository ));
81
+ })
82
+ .collect (Collectors .toMap (
83
+ Map .Entry ::getKey ,
84
+ Map .Entry ::getValue
85
+ ));
86
+ }
67
87
if (entityClassByTableName == null ) {
68
- entityClassByTableName = Optional .ofNullable (entityManagerFactories ).orElse (Collections .emptyList ()).stream ()
69
- .map (LocalContainerEntityManagerFactoryBean ::getPersistenceUnitInfo )
70
- .map (PersistenceUnitInfo ::getManagedClassNames )
71
- .flatMap (Collection ::stream )
72
- .map (TypeParser ::parse )
88
+ entityClassByTableName = crudRepositoryByClass .keySet ().stream ()
73
89
.map (type -> (Class <?>) type )
74
- .filter (entityClass -> entityClass .isAnnotationPresent (Table .class ))
75
- .collect (Collectors .toMap (entityClass -> entityClass .getAnnotation (Table .class ).name (), Function .identity ()));
90
+ .sorted ((c1 , c2 ) -> {
91
+ if (TypeParser .defaultPackage == null ) return 0 ;
92
+
93
+ if (c1 .getPackageName ().startsWith (TypeParser .defaultPackage )) return -1 ;
94
+
95
+ return c2 .getPackageName ().startsWith (TypeParser .defaultPackage ) ? 1 : 0 ;
96
+ })
97
+ .<Map .Entry <String , Class <?>>>mapMulti ((clazz , consumer ) -> {
98
+ String tableName = Optional .ofNullable (clazz .getAnnotation (Table .class )).map (Table ::name )
99
+ .orElse (Optional .ofNullable (clazz .getAnnotation (org .springframework .data .relational .core .mapping .Table .class )).map (org .springframework .data .relational .core .mapping .Table ::value ).orElse (null ));
100
+ if (tableName != null ) consumer .accept (Map .entry (tableName , clazz ));
101
+ }).collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue , (t1 , t2 ) -> t1 ));
76
102
}
77
103
}
78
104
@@ -166,7 +192,7 @@ public <E> void the_repository_will_contain_with_type(Guard guard, CrudRepositor
166
192
.filter (entityManagerFactory -> entityManagerFactory .getPersistenceUnitInfo () != null )
167
193
.filter (entityManagerFactory -> entityManagerFactory .getPersistenceUnitInfo ().getManagedClassNames ().contains (entityClass .getName ()))
168
194
.map (LocalContainerEntityManagerFactoryBean ::getDataSource ).findFirst ()
169
- .orElseThrow ( );
195
+ .orElse ( entityManagerFactories . get ( 0 ). getDataSource () );
170
196
new JdbcTemplate (dataSource ).update ("TRUNCATE %s RESTART IDENTITY CASCADE" .formatted (table ));
171
197
}
172
198
repository .saveAll (Mapper .readAsAListOf (entities , entityClass ));
@@ -190,43 +216,20 @@ public <E> void add_repository_content_to_variable(Guard guard, String name, Cru
190
216
}
191
217
192
218
public <E > CrudRepository <E , ?> getRepositoryForTable (String table ) {
193
- return getRepositoryForEntity (entityTypeByTableNameOrClassName ( table ));
219
+ return getRepositoryForEntity (Optional . ofNullable ( this . entityClassByTableName . get ( table )). orElseGet (() -> TypeParser . parse ( table ) ));
194
220
}
195
221
196
222
@ Nullable
197
223
private Type entityTypeByTableNameOrClassName (String entityTableOrClass ) {
198
- return entityClassByTableName . containsKey ( entityTableOrClass ) ? entityClassByTableName .get (entityTableOrClass ) : TypeParser .parse (entityTableOrClass );
224
+ return Optional . ofNullable ( entityClassByTableName .get (entityTableOrClass )). orElseGet (() -> TypeParser .parse (entityTableOrClass ) );
199
225
}
200
226
201
227
@ SuppressWarnings ({"unchecked" })
202
228
public <E > CrudRepository <E , ?> getRepositoryForEntity (Type type ) {
203
- if (Types .rawTypeOf (type ).isAnnotationPresent (Entity .class )) {
204
- return spring .applicationContext ().getBeansOfType (CrudRepository .class ).values ()
205
- .stream ()
206
- .map (bean -> (CrudRepository <E , ?>) bean )
207
- .filter (r -> {
208
- Map <TypeVariable <?>, Type > typeArguments = TypeUtils .getTypeArguments (r .getClass (), CrudRepository .class );
209
- if (type .equals (TypeUtils .unrollVariables (typeArguments , CrudRepository .class .getTypeParameters ()[0 ])))
210
- return true ;
211
-
212
- if (type instanceof Class <?> clazz
213
- && typeArguments .get (CrudRepository .class .getTypeParameters ()[0 ]) instanceof TypeVariable <?> typeVariable ) {
214
- Type handledEntitySuperclass = typeVariable .getBounds ()[0 ];
215
-
216
- List <Class <?>> superClasses = new ArrayList <>();
217
- while (clazz != Object .class ) {
218
- superClasses .add (clazz );
219
- clazz = clazz .getSuperclass ();
220
- }
221
-
222
- return superClasses .contains (handledEntitySuperclass );
223
- }
229
+ CrudRepository <?, ?> crudRepository = crudRepositoryByClass .get (type );
230
+ if (crudRepository == null ) throw new AssertionError (type + " is not an Entity!" );
224
231
225
- return false ;
226
- }).sorted ((r1 , r2 ) -> TypeUtils .getTypeArguments (r1 .getClass (), CrudRepository .class ).values ().stream ().findFirst ().orElse (null ) instanceof TypeVariable <?> typeVariable ? 1 : 0 )
227
- .findFirst ().orElseThrow (() -> new AssertionError ("there was no CrudRepository found for entity %s! If you don't need one in your app, you must create one in your tests!" .formatted (type .getTypeName ())));
228
- }
229
- throw new AssertionError (type + " is not an Entity!" );
232
+ return (CrudRepository <E , ?>) crudRepository ;
230
233
}
231
234
232
235
public <E > CrudRepository <E , ?> getRepositoryByType (Type type ) {
0 commit comments