forked from AbsInt/CompCert
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDenotationalSimulationAdditive.v
334 lines (295 loc) · 12.5 KB
/
DenotationalSimulationAdditive.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
(* Denotational correctness proof for the additive constant optimization.
As the name suggests, the transformation drops additions of
constants from a model block -- the total constant dropped can be
a function of the *data* input to the program, but it cannot depend
upon the parameters.
Let p be a program and tp be the transformed program. We assume
there is a function target_const : list val -> R such that
target_const(d) is the negation of the constant dropped by the
transform when the data is d.
The proof here takes as input a forward simulation between p and tp
such that, with data d and parameters ρ, when the "testval" of the
final state of the simulation is t on program p, the "testval" for
tp is target_const (d) + t.
From this we deduce that
log_density_of_program p d ρ = log_density_of_program tp d ρ - target_const d
Hence
density_of_program p d ρ = c * density_of_program tp d ρ
where c is exp(-target_const d).
Since integration is a linear operator, the unnormalized distribution
and the normalizing constant of the target program are those of p
multiplied by 1/c = exp(target_const d). When we divide one by the
other, this factor cancels, giving the same distribution as the
original program p.
*)
Require Import Coqlib Errors Maps String.
Local Open Scope string_scope.
Require Import Integers Floats Values AST Memory Builtins Events Globalenvs.
Require Import Ctypes Cop Stanlight.
Require Import Smallstep.
Require Import Linking.
Require Import IteratedRInt.
Require Import StanEnv.
Require Vector.
Require Import Clightdefs.
Import Clightdefs.ClightNotations.
Local Open Scope clight_scope.
Require ClassicalEpsilon.
Require Import Reals.
From Coq Require Import Reals Psatz ssreflect ssrbool Utf8.
Require Import Ssemantics.
(* Section parameters: [prog] is the source program and [tprog] the
   transformed program; [transf_correct] below states how they relate. *)
Section DENOTATIONAL_SIMULATION.
Variable prog: Stanlight.program.
Variable tprog: Stanlight.program.
(* Safety of prog is not assumed globally: each lemma below that needs it takes an explicit [is_safe prog data params] (or [safe_data prog data]) hypothesis. *)
(* A safe run of prog always has an initial state, whatever the target-value
   argument [t]: this is the first component of [is_safe]. *)
Lemma inhabited_initial :
∀ data params t, is_safe prog data params -> ∃ s, Smallstep.initial_state (semantics prog data params t) s.
Proof.
intros data params t [Hinit_ex _].
eapply Hinit_ex.
Qed.
(* [target_const data] is the shift between the target values of prog and
   tprog on [data]: the simulation relates target value [t] of prog to
   [target_const data + t] of tprog (see [transf_correct]). *)
Variable target_const : list val -> R.
(* Forward simulation hypothesis: when the global environment of prog provides
   the math library and prog is safe on [data]/[params], the semantics of prog
   run with target-value argument [IRF t] is forward-simulated by the semantics
   of tprog run with [IRF (target_const data + t)]. *)
Variable transf_correct:
forall data params t,
genv_has_mathlib (globalenv prog) ->
is_safe prog data params ->
forward_simulation (Ssemantics.semantics prog data params (IRF t)) (Ssemantics.semantics tprog data params (IRF (target_const data + t))).
(* The transformation leaves the flattened parameter variables unchanged. *)
Variable parameters_preserved:
flatten_parameter_variables tprog = flatten_parameter_variables prog.
(* External functions of the two global environments correspond. *)
Variable external_funct_preserved:
match_external_funct (globalenv prog) (globalenv tprog).
(* The two global environments are equivalent as symbol environments.
   NOTE(review): not used by any proof visible in this section. *)
Variable global_env_equiv :
Senv.equiv (globalenv prog) (globalenv tprog).
(* Symbol lookup returns the same result in both global environments. *)
Variable symbols_preserved:
forall id,
Genv.find_symbol (globalenv tprog) id = Genv.find_symbol (globalenv prog) id.
(* The math library required by [genv_has_mathlib] is also available in the
   global environment of tprog: symbol lookups are transported with
   [symbols_preserved], and the external functions with
   [external_funct_preserved]. *)
Lemma tprog_genv_has_mathlib :
genv_has_mathlib (globalenv prog) ->
genv_has_mathlib (globalenv tprog).
Proof.
(* [destruct 1] exposes the fields of the hypothesis (GENV_EXP, GENV_LOG, ...). *)
destruct 1.
(* Unfold every per-function spec and replace the symbol lookups of tprog
   by those of prog. *)
split; rewrite /genv_exp_spec/genv_log_spec/genv_expit_spec/genv_normal_lpdf_spec/genv_normal_lupdf_spec;
rewrite /genv_cauchy_lpdf_spec/genv_cauchy_lupdf_spec;
rewrite ?symbols_preserved.
intuition.
(* One remaining goal per library function, in the order exp, expit, log,
   normal_lpdf, normal_lupdf, cauchy_lpdf, cauchy_lupdf. Each reuses the
   location found for prog; the first conjunct follows from the hypothesis
   and the second is transported by [external_funct_preserved]. *)
{ destruct GENV_EXP as (loc&?). exists loc. split; first by intuition.
eapply external_funct_preserved; intuition eauto. }
{ destruct GENV_EXPIT as (loc&?). exists loc. split; first by intuition.
eapply external_funct_preserved; intuition eauto. }
{ destruct GENV_LOG as (loc&?). exists loc. split; first by intuition.
eapply external_funct_preserved; intuition eauto. }
{ destruct GENV_NORMAL_LPDF as (loc&Hnor). exists loc. split; first by intuition.
eapply external_funct_preserved; intuition eauto. }
{ destruct GENV_NORMAL_LUPDF as (loc&Hnor). exists loc. split; first by intuition.
eapply external_funct_preserved; intuition eauto. }
{ destruct GENV_CAUCHY_LPDF as (loc&Hnor). exists loc. split; first by intuition.
eapply external_funct_preserved; intuition eauto. }
{ destruct GENV_CAUCHY_LUPDF as (loc&Hnor). exists loc. split; first by intuition.
eapply external_funct_preserved; intuition eauto. }
Qed.
(* Generic lemma: if [tp] matches [p] (functions related by [f], variables by
   equality) and both declare the same parameter variables, then flattening
   the parameter variables gives the same result for both programs. *)
Lemma match_flatten_parameter_variables (p tp : program) f :
match_program f eq p tp ->
pr_parameters_vars p = pr_parameters_vars tp ->
flatten_parameter_variables tp = flatten_parameter_variables p.
Proof.
intros Hmatch Heq.
unfold flatten_parameter_variables. simpl.
unfold flatten_ident_variable_list.
rewrite Heq.
f_equal. f_equal.
(* It remains to show that each parameter is looked up identically
   ([lookup_def_ident]) in the definitions of both programs. *)
apply List.map_ext.
intros ((id&b)&f').
f_equal.
unfold lookup_def_ident.
destruct Hmatch as (H1&H2).
simpl in H1.
(* Relate the two searches for the first global variable named [id], using
   the pointwise relation [H1] between the two definition lists. *)
edestruct (@list_find_fst_forall2 _ (AST.globdef fundef variable)
((fun '(id', v) => Pos.eq_dec id id' && is_gvar v))) as [Hleft|Hright]; first eauto.
{ intros ?? (?&?); auto. }
(* The search predicate agrees on related definitions. *)
{ intros (?&?) (?&?). inversion 1 as [Hfst Hglob].
simpl in Hfst; subst. simpl in Hglob. inversion Hglob. subst.
* rewrite //=.
* subst. rewrite //=.
}
(* Both searches succeed on related definitions: the variables are equal.
   NOTE(review): [inversion H] below relies on an auto-generated hypothesis
   name and may break if earlier tactics change. *)
{ simpl. destruct Hleft as (id'&g1&g2&->&->&Hident).
inversion Hident as [Hfst_eq Hglob]. simpl in Hglob.
inversion Hglob; auto.
subst. inversion H. congruence. }
(* Both searches fail. *)
{ destruct Hright as (->&->). auto. }
Qed.
(* Generic lemma: if [tp] is obtained from [p] by a per-function translation
   [transf_fundef] that leaves external functions unchanged and never turns an
   internal function into an external one, then the external functions of the
   two global environments correspond ([match_external_funct]). *)
Lemma match_program_external_funct (p tp : program) transf_fundef :
match_program (fun ctx f tf => tf = transf_fundef f) eq p tp ->
(∀ ef tyargs tyres cconv,
transf_fundef (Ctypes.External ef tyargs tyres cconv) =
Ctypes.External ef tyargs tyres cconv) ->
(∀ f ef tyargs tyres cconv,
transf_fundef (Internal f) <> External ef tyargs tyres cconv) ->
match_external_funct (globalenv p) (globalenv tp).
Proof.
intros Hmatch Hext Hint.
- unfold match_external_funct, sub_external_funct.
split.
(* From p to tp: an external function found in p is found, translated
   (hence unchanged, by [Hext]), in tp. *)
* intros. rewrite -Hext. eapply @Genv.find_funct_transf; eauto.
(* From tp to p: recover the source function with [find_funct_transf_rev];
   it cannot be internal (by [Hint]), so it is the same external function. *)
* intros.
edestruct (Genv.find_funct_transf_rev Hmatch) as (p'&->&Htransf); eauto.
destruct p'; simpl in Htransf; try congruence.
{ exfalso. eapply Hint. eauto. }
rewrite Hext in Htransf.
inversion Htransf. subst. eauto.
Qed.
(* Once unfolded, the two parameter dimensions differ only in
   [flatten_parameter_variables], which [parameters_preserved] equates. *)
Lemma dimen_preserved:
parameter_dimension tprog = parameter_dimension prog.
Proof.
rewrite /parameter_dimension.
rewrite /flatten_parameter_constraints.
by rewrite parameters_preserved.
Qed.
(* The results in this subsection additionally assume that the global
   environment of prog provides the math library. *)
Section has_mathlib.
Variable MATH: genv_has_mathlib (globalenv prog).
(* If prog, safe on [data]/[params], satisfies [returns_target_value] with
   [IRF t], then tprog satisfies it with [IRF (target_const data + t)]. *)
Lemma returns_target_value_fsim data params t:
is_safe prog data params ->
returns_target_value prog data params (IRF t) ->
returns_target_value tprog data params (IRF (target_const data + t)).
Proof.
intros Hsafe.
intros (s1&s2&Hinit&Hstar&Hfinal).
destruct (transf_correct data params t) as [index order match_states props]; eauto.
(* Match the initial state, simulate the run, then match the final state. *)
edestruct (fsim_match_initial_states) as (?&s1'&Hinit'&Hmatch1); eauto.
edestruct (simulation_star) as (?&s2'&Hstar'&Hmatch2); eauto.
eapply (fsim_match_final_states) in Hmatch2; eauto.
exists s1', s2'; auto.
Qed.
(* Converse of [returns_target_value_fsim]: if tprog satisfies
   [returns_target_value] with [IRF (target_const data + t)], then prog
   satisfies it with [IRF t], provided prog is safe on [data]/[params]. *)
Lemma returns_target_value_bsim data params t:
is_safe prog data params ->
returns_target_value tprog data params (IRF (target_const data + t)) ->
returns_target_value prog data params (IRF t).
Proof.
intros Hsafe (s1&s2&Hinit&Hstar&Hfinal).
(* Turn the forward simulation into a backward one; its side conditions are
   discharged with [semantics_determinate] and [semantics_receptive]. *)
specialize (transf_correct data params t) as Hfsim.
apply forward_to_backward_simulation in Hfsim as Hbsim;
auto using semantics_determinate, semantics_receptive.
destruct Hbsim as [index order match_states props].
(* Matching initial states needs some initial state of prog; safety gives one. *)
assert (∃ s10, Smallstep.initial_state (semantics prog data params (IRF t)) s10) as (s10&?).
{ apply inhabited_initial; eauto. }
(* Match the initial state of tprog, replay its run backwards with
   [bsim_E0_star] (which uses safety of prog), then match the final state. *)
edestruct (bsim_match_initial_states) as (?&s1'&Hinit'&Hmatch1); eauto.
edestruct (bsim_E0_star) as (?&s2'&Hstar'&Hmatch2); eauto.
{ eapply Hsafe; eauto. }
eapply (bsim_match_final_states) in Hmatch2 as (s2''&?&?); eauto; last first.
{ eapply star_safe; last eapply Hsafe; eauto. }
(* Glue the two source runs together. *)
exists s1', s2''. intuition eauto.
{ eapply star_trans; eauto. }
Qed.
(* On data/parameters where prog is safe, the log density of tprog is the log
   density of prog shifted by [target_const data]. *)
Lemma log_density_equiv data params :
is_safe prog data params ->
target_const data + log_density_of_program prog data params = log_density_of_program tprog data params.
Proof.
intros HP.
rewrite {1}/log_density_of_program.
rewrite /pred_to_default_fun.
(* Case split on whether prog returns some target value. *)
destruct (ClassicalEpsilon.excluded_middle_informative) as [(v&Hreturns)|Hne].
(* prog returns the chosen value [x]: transport it to tprog with the forward
   direction and read off the log density of tprog. *)
{ destruct (ClassicalEpsilon.constructive_indefinite_description) as [x Hx].
symmetry.
replace x with (IRF (IFR x)) in Hx; last by (rewrite IRF_IFR_inv).
exploit returns_target_value_fsim; eauto.
intros Heq%log_density_of_program_trace. rewrite Heq.
rewrite IFR_IRF_inv //.
}
(* prog returns no target value: unfold the log density of tprog as well. *)
symmetry.
rewrite {1}/log_density_of_program.
rewrite /pred_to_default_fun.
destruct (ClassicalEpsilon.excluded_middle_informative) as [(v&Hreturns)|Hne']; auto.
(* tprog returns [v]: write [v] as a value shifted by [target_const data] and
   pull it back to prog with the backward direction, contradicting [Hne]. *)
{
exfalso. apply Hne.
assert (v = IRF (target_const data + (IFR v - target_const data))) as Heq.
{ rewrite -{1}(IRF_IFR_inv v). f_equal. nra. }
rewrite Heq in Hreturns.
eexists.
exploit returns_target_value_bsim; eauto.
}
(* Neither program returns: contradicts the safety hypothesis [HP], which
   provides a returned target value for prog. *)
{ exfalso. eapply Hne. eapply HP. }
Qed.
(* Safe data transfers from prog to tprog: if prog is safe on [data] for every
   parameter vector in its parameter rectangle, then so is tprog. *)
Lemma safe_data_preserved :
∀ data, safe_data prog data -> safe_data tprog data.
Proof.
intros data Hsafe.
rewrite /safe_data. intros params Hin.
(* The two programs have the same parameter rectangle. *)
assert (Hin': in_list_rectangle params (parameter_list_rect prog)).
{ move:Hin. rewrite /parameter_list_rect/flatten_parameter_constraints parameters_preserved //. }
specialize (Hsafe _ Hin').
rewrite /is_safe. split.
(* Initial states of tprog: take an initial state of prog for the target
   value shifted by [- target_const data] and match it forward. *)
{ intros t.
unshelve (edestruct Hsafe as ((s&Hinit)&_)).
{ exact (IRF (-target_const data + IFR t)). }
exploit (transf_correct data); eauto. intros Hfsim.
destruct Hfsim. edestruct fsim_match_initial_states as (ind&s'&?); eauto.
exists s'. intuition.
}
split.
(* Second component of [is_safe], for runs of tprog starting from an initial
   state: obtained from that of prog through the backward simulation
   ([bsim_safe]), again using the shifted target value. *)
{
intros t s Hinit.
epose proof (transf_correct data (map R2val params) (- target_const data + IFR t)) as Hfsim.
apply forward_to_backward_simulation in Hfsim as Hbsim;
auto using semantics_determinate, semantics_receptive.
edestruct Hbsim as [index order match_states props].
eassert (∃ s10, Smallstep.initial_state (semantics prog data (map (λ r, Vfloat (IRF r)) params) _) s10)
as (s10&?).
{ apply inhabited_initial; eauto. }
edestruct (bsim_match_initial_states) as (?&s1'&Hinit'&Hmatch1); eauto.
eapply bsim_safe; eauto.
{ assert ((IRF (target_const data + (- target_const data + IFR t))) = t) as Heq.
{ rewrite -{2}(IRF_IFR_inv t). f_equal. nra. }
rewrite Heq in props. eauto. }
apply Hsafe. eauto.
}
(* A returned target value for tprog: transport the one of prog forward. *)
{
edestruct Hsafe as (?&?&Hret). destruct Hret as (t&?).
eexists. apply returns_target_value_fsim; eauto.
erewrite IRF_IFR_inv; eauto.
}
Qed.
End has_mathlib.
(* Once unfolded, the two parameter rectangles differ only in
   [flatten_parameter_variables], which [parameters_preserved] equates. *)
Lemma parameter_list_rect_preserved :
parameter_list_rect tprog = parameter_list_rect prog.
Proof.
rewrite /parameter_list_rect /flatten_parameter_constraints.
by rewrite parameters_preserved.
Qed.
(* Main result: [denotational_refinement tprog prog]. Its components, proven
   below in order, are: equal parameter dimensions, preservation of safe
   data, preservation of the program distribution, availability of the math
   library for tprog, and equal parameter rectangles. *)
Lemma denotational_preserved :
denotational_refinement tprog prog.
Proof.
exists (dimen_preserved).
split; [| split; [| split]].
(* Safe data is preserved (the extra hypothesis is presumably the math-library
   assumption required by [safe_data_preserved]). *)
- intros data Hsafe ?. apply safe_data_preserved; auto.
(* Same distribution: the normalizing constant and the unnormalized
   distribution of tprog are those of prog scaled by [exp (target_const data)],
   so their quotient is unchanged. *)
- intros data rt vt Hsafe Hmath Hwf.
rewrite /is_program_distribution/is_program_normalizing_constant/is_unnormalized_program_distribution.
intros (vnum&vnorm&Hneq0&His_norm&His_num&Hdiv).
exists (exp (target_const data) * vnum), (exp (target_const data) * vnorm). repeat split; auto.
(* Nonzero condition, from [Hneq0] and positivity of [exp]. *)
{ specialize (exp_pos (target_const data)). nra. }
(* Normalizing constant: pointwise, the density of tprog is
   [exp (target_const data)] times that of prog ([log_density_equiv]). *)
{
rewrite parameter_list_rect_preserved.
eapply is_IIRInt_list_ext; last first.
{ apply is_IIRInt_list_scal; eauto. }
{ intros x Hin. rewrite /program_normalizing_constant_integrand.
{ rewrite /density_of_program/= -log_density_equiv //; try eapply Hsafe; eauto.
rewrite exp_plus //=. }
}
{ eauto. }
}
(* Unnormalized distribution: same scaling; the parameter output map is the
   same for both programs ([eval_expr_fun_match]). *)
{
rewrite parameter_list_rect_preserved.
eapply is_IIRInt_list_ext; last first.
{ apply is_IIRInt_list_scal; eauto. }
{ intros x Hin. rewrite /unnormalized_program_distribution_integrand.
{ rewrite /density_of_program/= -log_density_equiv //; try eapply Hsafe; eauto.
rewrite exp_plus //=. rewrite /Hierarchy.scal/=/Hierarchy.mult /=.
rewrite Rmult_assoc.
f_equal. f_equal.
rewrite /eval_param_map_list //=.
rewrite /flatten_parameter_out parameters_preserved //.
f_equal. eapply map_ext.
intros (r&f) => /=. f_equal.
apply eval_expr_fun_match; eauto.
}
}
{ eauto. }
}
(* The common factor [exp (target_const data)] cancels in the quotient. *)
{ rewrite Hdiv. field.
specialize (exp_pos (target_const data)). nra.
}
(* The math library is available for tprog. *)
- intros. apply tprog_genv_has_mathlib; auto.
(* Parameter rectangles coincide (same rewriting as in the proof of
   [parameter_list_rect_preserved]). *)
- rewrite /parameter_list_rect/flatten_parameter_constraints. rewrite parameters_preserved //.
Qed.
End DENOTATIONAL_SIMULATION.