16
16
17
17
package dagger .hilt .android .internal .lifecycle ;
18
18
19
+ import static androidx .lifecycle .SavedStateHandleSupport .createSavedStateHandle ;
20
+
19
21
import android .app .Activity ;
20
22
import android .os .Bundle ;
21
23
import androidx .annotation .NonNull ;
22
24
import androidx .annotation .Nullable ;
23
- import androidx .lifecycle .AbstractSavedStateViewModelFactory ;
24
- import androidx .lifecycle .SavedStateHandle ;
25
25
import androidx .lifecycle .ViewModel ;
26
26
import androidx .lifecycle .ViewModelProvider ;
27
27
import androidx .lifecycle .viewmodel .CreationExtras ;
37
37
import java .util .Map ;
38
38
import java .util .Set ;
39
39
import javax .inject .Provider ;
40
+ import kotlin .jvm .functions .Function1 ;
40
41
41
42
/**
42
43
* View Model Provider Factory for the Hilt Extension.
@@ -55,20 +56,32 @@ public final class HiltViewModelFactory implements ViewModelProvider.Factory {
55
56
public interface ViewModelFactoriesEntryPoint {
56
57
@ HiltViewModelMap
57
58
Map <String , Provider <ViewModel >> getHiltViewModelMap ();
59
+
60
+ // From ViewModel class names to user defined @AssistedFactory-annotated implementations.
61
+ @ HiltViewModelAssistedMap
62
+ Map <String , Object > getHiltViewModelAssistedMap ();
58
63
}
59
64
65
+ /** Creation extra key for the callbacks that create @AssistedInject-annotated ViewModels. */
66
+ public static final CreationExtras .Key <Function1 <Object , ViewModel >> CREATION_CALLBACK_KEY =
67
+ new CreationExtras .Key <Function1 <Object , ViewModel >>() {};
68
+
60
69
/** Hilt module for providing the empty multi-binding map of ViewModels. */
61
70
@ Module
62
71
@ InstallIn (ViewModelComponent .class )
63
72
interface ViewModelModule {
64
73
@ Multibinds
65
74
@ HiltViewModelMap
66
75
Map <String , ViewModel > hiltViewModelMap ();
76
+
77
+ @ Multibinds
78
+ @ HiltViewModelAssistedMap
79
+ Map <String , Object > hiltViewModelAssistedMap ();
67
80
}
68
81
69
82
private final Set <String > hiltViewModelKeys ;
70
83
private final ViewModelProvider .Factory delegateFactory ;
71
- private final AbstractSavedStateViewModelFactory hiltViewModelFactory ;
84
+ private final ViewModelProvider . Factory hiltViewModelFactory ;
72
85
73
86
public HiltViewModelFactory (
74
87
@ NonNull Set <String > hiltViewModelKeys ,
@@ -77,31 +90,75 @@ public HiltViewModelFactory(
77
90
this .hiltViewModelKeys = hiltViewModelKeys ;
78
91
this .delegateFactory = delegateFactory ;
79
92
this .hiltViewModelFactory =
80
- new AbstractSavedStateViewModelFactory () {
93
+ new ViewModelProvider . Factory () {
81
94
@ NonNull
82
95
@ Override
83
- @ SuppressWarnings ("unchecked" )
84
- protected <T extends ViewModel > T create (
85
- @ NonNull String key , @ NonNull Class <T > modelClass , @ NonNull SavedStateHandle handle ) {
96
+ public <T extends ViewModel > T create (
97
+ @ NonNull Class <T > modelClass , @ NonNull CreationExtras extras ) {
86
98
RetainedLifecycleImpl lifecycle = new RetainedLifecycleImpl ();
87
- ViewModelComponent component = viewModelComponentBuilder
88
- .savedStateHandle (handle )
89
- .viewModelLifecycle (lifecycle )
90
- .build ();
99
+ ViewModelComponent component =
100
+ viewModelComponentBuilder
101
+ .savedStateHandle (createSavedStateHandle (extras ))
102
+ .viewModelLifecycle (lifecycle )
103
+ .build ();
104
+ T viewModel = createViewModel (component , modelClass , extras );
105
+ viewModel .addCloseable (lifecycle ::dispatchOnCleared );
106
+ return viewModel ;
107
+ }
108
+
109
+ private <T extends ViewModel > T createViewModel (
110
+ @ NonNull ViewModelComponent component ,
111
+ @ NonNull Class <T > modelClass ,
112
+ @ NonNull CreationExtras extras ) {
91
113
Provider <? extends ViewModel > provider =
92
114
EntryPoints .get (component , ViewModelFactoriesEntryPoint .class )
93
115
.getHiltViewModelMap ()
94
116
.get (modelClass .getName ());
95
- if (provider == null ) {
96
- throw new IllegalStateException (
97
- "Expected the @HiltViewModel-annotated class '"
98
- + modelClass .getName ()
99
- + "' to be available in the multi-binding of "
100
- + "@HiltViewModelMap but none was found." );
117
+ Function1 <Object , ViewModel > creationCallback = extras .get (CREATION_CALLBACK_KEY );
118
+ Object assistedFactory =
119
+ EntryPoints .get (component , ViewModelFactoriesEntryPoint .class )
120
+ .getHiltViewModelAssistedMap ()
121
+ .get (modelClass .getName ());
122
+
123
+ if (assistedFactory == null ) {
124
+ if (creationCallback == null ) {
125
+ if (provider == null ) {
126
+ throw new IllegalStateException (
127
+ "Expected the @HiltViewModel-annotated class "
128
+ + modelClass .getName ()
129
+ + " to be available in the multi-binding of "
130
+ + "@HiltViewModelMap"
131
+ + " but none was found." );
132
+ } else {
133
+ return (T ) provider .get ();
134
+ }
135
+ } else {
136
+ // Provider could be null or non-null.
137
+ throw new IllegalStateException (
138
+ "Found creation callback but class "
139
+ + modelClass .getName ()
140
+ + " does not have an assisted factory specified in @HiltViewModel." );
141
+ }
142
+ } else {
143
+ if (provider == null ) {
144
+ if (creationCallback == null ) {
145
+ throw new IllegalStateException (
146
+ "Found @HiltViewModel-annotated class "
147
+ + modelClass .getName ()
148
+ + " using @AssistedInject but no creation callback"
149
+ + " was provided in CreationExtras." );
150
+ } else {
151
+ return (T ) creationCallback .invoke (assistedFactory );
152
+ }
153
+ } else {
154
+ // Creation callback could be null or non-null.
155
+ throw new AssertionError (
156
+ "Found the @HiltViewModel-annotated class "
157
+ + modelClass .getName ()
158
+ + " in both the multi-bindings of "
159
+ + "@HiltViewModelMap and @HiltViewModelAssistedMap." );
160
+ }
101
161
}
102
- ViewModel viewModel = provider .get ();
103
- viewModel .addCloseable (lifecycle ::dispatchOnCleared );
104
- return (T ) viewModel ;
105
162
}
106
163
};
107
164
}
0 commit comments