forked from AbsInt/CompCert
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathReparameterization.v
372 lines (332 loc) · 14 KB
/
Reparameterization.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
(* Reparameterization transform.
This transformation removes constraints on parameters in a model.
After this transform, the parameter inputs to the model block are
assumed to be unconstrained (i.e. arbitrary floats or
reals). Therefore, at each use of a parameter variable, the pass
inserts code to remap the parameter variable into the original
constrained space.
This amounts to a change of variables in an integral, so the pass
also inserts a Jacobian correction at the *end* of the model
block that adds the logarithm of the Jacobian to the target.
*)
Require Import List.
Require Import String.
Require Import ZArith.
Require Floats.
Require Integers.
Local Open Scope Z_scope.
Local Open Scope string_scope.
Require StanEnv.
Require Import Stanlight.
Require Errors.
Require Import Clightdefs.
Import Clightdefs.ClightNotations.
Local Open Scope clight_scope.
(* Error-monad bind notation: [do X <- A ; B] stands for
   Errors.bind A (fun X => B). Assuming Errors.bind is the usual
   CompCert error-monad bind, an Errors.Error produced by A is
   propagated and B is not evaluated.
   NOTE(review): the scope is named gensym_monad_scope although the
   notation here is bound to Errors.bind, not to a gensym monad. *)
Notation "'do' X <- A ; B" := (Errors.bind A (fun X => B))
(at level 200, X ident, A at level 100, B at level 200)
: gensym_monad_scope.
Local Open Scope gensym_monad_scope.
(* pmap stores, for each parameter identifier, the code to remap the
value into constrained space; identifiers that are not parameters
map to None. *)
(* transf_expr rewrites every read of a parameter in e. A variable, or
   an indexed variable, whose identifier id has pmap id = Some fe is
   replaced by fe applied to that read. Every other expression is
   rebuilt with its subexpressions transformed recursively.
   Fails if a parameter is read at a type other than Breal. *)
Fixpoint transf_expr (pmap: AST.ident -> option (expr -> expr)) (e: Stanlight.expr) {struct e}: Errors.res Stanlight.expr :=
match e with
(* Scalar parameter read: only type Breal is accepted. *)
| Evar id ty =>
match pmap id with
| Some fe =>
match ty with
| Breal => Errors.OK (fe (Evar id Breal))
| _ => Errors.Error (Errors.msg "Reparameterization: parameter loaded with non real type")
end
| None => Errors.OK (Evar id ty)
end
(* The callee expression is transformed as well. Parameters whose
   identifier is a math function identifier are rejected beforehand
   by param_check_shadow in transf_program. *)
| Ecall e el ty =>
do e <- transf_expr pmap e;
do el <- transf_exprlist pmap el;
Errors.OK (Ecall e el ty)
| Eunop o e ty =>
do e <- transf_expr pmap e;
Errors.OK (Eunop o e ty)
| Ebinop e1 o e2 ty =>
do e1 <- transf_expr pmap e1;
do e2 <- transf_expr pmap e2;
Errors.OK (Ebinop e1 o e2 ty)
(* Indices are always transformed. If the base is a parameter
   variable, the remap is applied to the indexed element, so an array
   parameter is remapped element by element.
   NOTE(review): a base that is not an Evar is returned as is,
   without being transformed recursively. *)
| Eindexed e el ty =>
do el <- transf_exprlist pmap el;
match e with
| Evar id _ =>
match pmap id with
| Some fe =>
match ty with
| Breal => Errors.OK (fe (Eindexed e el Breal))
| _ => Errors.Error (Errors.msg "Reparameterization: parameter loaded with non real type")
end
| None => Errors.OK (Eindexed e el ty)
end
| _ => Errors.OK (Eindexed e el ty)
end
| Ecast e ty =>
do e <- transf_expr pmap e;
Errors.OK (Ecast e ty)
(* Constants and target reads contain no parameter reads. *)
| Econst_int a b => Errors.OK (Econst_int a b)
| Econst_float a b => Errors.OK (Econst_float a b)
| Etarget b => Errors.OK (Etarget b)
end
(* transf_expr applied to each element of an expression list;
   fails if any element fails. *)
with transf_exprlist (pmap: AST.ident -> option (expr -> expr)) (el: exprlist) {struct el} : Errors.res exprlist :=
match el with
| Enil => Errors.OK Enil
| Econs e el =>
do e <- transf_expr pmap e;
do el <- transf_exprlist pmap el;
Errors.OK (Econs e el)
end.
(* Apply transf_expr to every expression occurring in statement s, so
   that each read of a parameter listed in pmap goes through its
   unconstrained-to-constrained remapping. The statement structure is
   preserved, and the first rejected expression makes the whole
   statement fail. Stilde is rejected: sampling statements are expected
   to have been eliminated by the earlier Sampling pass. *)
Fixpoint transf_statement (pmap: AST.ident -> option (expr -> expr))
  (s: Stanlight.statement) {struct s} : Errors.res (Stanlight.statement) :=
  match s with
  | Sskip => Errors.OK (Sskip)
  | Sassign e1 o e2 =>
    (* The left-hand side is transformed too: a parameter appearing
       there would be turned into its remapped expression. *)
    do e1 <- transf_expr pmap e1;
    do e2 <- transf_expr pmap e2;
    Errors.OK (Sassign e1 o e2)
  | Ssequence s1 s2 =>
    do s1 <- (transf_statement pmap s1);
    do s2 <- (transf_statement pmap s2);
    Errors.OK (Ssequence s1 s2)
  | Sifthenelse e s1 s2 =>
    do e <- (transf_expr pmap e);
    do s1 <- (transf_statement pmap s1);
    do s2 <- (transf_statement pmap s2);
    Errors.OK (Sifthenelse e s1 s2)
  | Sfor i e1 e2 s =>
    (* The loop index i is kept as is; only the bounds and body are
       transformed. *)
    do e1 <- transf_expr pmap e1;
    do e2 <- transf_expr pmap e2;
    do s <- transf_statement pmap s;
    Errors.OK (Sfor i e1 e2 s)
  | Starget e =>
    do e <- transf_expr pmap e;
    Errors.OK (Starget e)
  | Stilde e d el =>
    Errors.Error (Errors.msg "Reparameterization: detected Stilde, but should have been removed in Sampling")
  end.
(* Reject a function-local variable v whose identifier is a parameter,
   that is, has an entry in pmap. Otherwise transf_expr would rewrite
   reads of that local as if they were reads of the parameter.
   Only the identifier fst v is inspected; its type is ignored. *)
Definition check_non_param (pmap: AST.ident -> option (expr -> expr)) (v: AST.ident * basic) : Errors.res unit :=
  match pmap (fst v) with
  | None => Errors.OK tt
  | Some _ =>
    Errors.Error (Errors.msg "Reparameterization: function's local shadows a parameter")
  end.
(* Reject a variable whose identifier equals one of the identifiers in
   StanEnv.math_idents, i.e. a variable that would shadow a global
   math function. Returns Errors.OK tt when the identifier does not
   occur in that list; the type component of p is ignored. *)
Definition vars_check_shadow (p: AST.ident * basic) :=
  match p with
  | (x, _) =>
    match forallb (fun m => if Pos.eq_dec m x then false else true) StanEnv.math_idents with
    | true => Errors.OK tt
    | false =>
      Errors.Error (Errors.msg "Reparameterization: variable shadows global math functions")
    end
  end.
(* Transform one internal function, in this order:
   1. reject locals (fn_vars) that are parameters (check_non_param);
   2. reject locals named like a math function (vars_check_shadow);
   3. remap every parameter read in the body (transf_statement);
   4. append Starget correction after the transformed body, so the
      Jacobian correction is added to the target once the original
      body has run.
   The local variable list is kept unchanged. *)
Definition transf_function (pmap: AST.ident -> option _) (correction: expr) (f: Stanlight.function): Errors.res
(Stanlight.function) :=
do _ <- Errors.mmap (check_non_param pmap) (f.(fn_vars));
do _ <- Errors.mmap (vars_check_shadow) (f.(fn_vars));
do body <- transf_statement pmap f.(fn_body);
let body := Ssequence body (Starget correction) in
Errors.OK (mkfunction body f.(fn_vars)).
(* Internal functions are transformed with transf_function; external
   functions are returned unchanged.
   NOTE(review): the same correction is appended to every internal
   function; presumably a program has a single internal (model)
   function - confirm. *)
Definition transf_fundef (pmap: AST.ident -> option _) (correction: expr) (fd: Stanlight.fundef) : Errors.res Stanlight.fundef :=
match fd with
| Ctypes.Internal f =>
do tf <- transf_function pmap correction f;
Errors.OK (Ctypes.Internal tf)
| Ctypes.External ef targs tres cc => Errors.OK (Ctypes.External ef targs tres cc)
end.
(* A global variable keeps its declared type, but its constraint is
   reset to Cidentity: after this pass parameters are unconstrained and
   the remapping is done at each use instead. The identifier argument
   is ignored.
   NOTE(review): transform_partial_program2 presumably applies this to
   every global variable, not only parameters - confirm that resetting
   the constraint of non-parameter globals is intended. *)
Definition transf_variable (_: AST.ident) (v: Stanlight.variable): Errors.res Stanlight.variable :=
Errors.OK (mkvariable (v.(vd_type)) (Cidentity)).
(* TODO: the use of this should be removed and basics should be omitted from pr_parameters_vars in syntax *)
(* Check that the type recorded for a parameter in pr_parameters_vars
   agrees with the declared type of its global variable. Only two
   shapes are accepted: Breal with Breal, and arrays of Breal with
   equal sizes. Every other pair, including two equal non-real types,
   yields false. *)
Definition valid_equiv_param_type (b1 b2 : basic) :=
match b1, b2 with
| Breal, Breal => true
| Barray Breal z1, Barray Breal z2 =>
if Z.eq_dec z1 z2 then true else false
| _, _ => false
end.
(* Soundness of valid_equiv_param_type: when the check succeeds, the
   two types are equal. *)
Lemma valid_equiv_param_type_spec b1 b2 :
valid_equiv_param_type b1 b2 = true ->
b1 = b2.
Proof.
(* Case split on both types. Cases where the check reduces to false
   are closed by inversion on false = true, and equal scalar cases by
   auto; presumably only the array/array case survives. *)
destruct b1, b2;
try (simpl; inversion 1; fail); auto;
try (simpl; destruct b1; inversion 1; fail); auto.
simpl. destruct b1, b2; try (inversion 1; fail).
(* Split on the size comparison: equal sizes give equality after
   substitution, unequal sizes make the hypothesis false = true. *)
destruct (Z.eq_dec).
{ intros; subst; auto. }
{ inversion 1. }
Qed.
(* Resolve a parameter entry (param, b, a) to its global variable.
   Scans defs in order, skipping function definitions and variables
   with other identifiers. On the first variable named param, checks
   its declared type against b with valid_equiv_param_type and returns
   (param, its variable info, a). Fails if no such variable exists or
   if the types disagree. *)
Fixpoint find_parameter {A} (defs: list (AST.ident * AST.globdef fundef variable)) (entry: AST.ident * basic * A) {struct defs}: Errors.res (AST.ident * variable * A) :=
let '(param, b, a) := entry in
match defs with
| nil => Errors.Error (Errors.msg "Reparameterization: parameter missing from list of global definitions")
| (id,def) :: defs =>
match def with
| AST.Gvar v =>
if positive_eq_dec id param then
if valid_equiv_param_type (vd_type (AST.gvar_info v)) b then
Errors.OK (param,v.(AST.gvar_info), a)
else
Errors.Error (Errors.msg "Reparameterization: parameter type inconsistent")
else find_parameter defs entry
| AST.Gfun _ => find_parameter defs entry
end
end.
(* This maps each constraint c into a function g : expr -> expr such
   that, given an expression i that computes an unconstrained value,
   g i is an expression that computes the corresponding value in the
   constrained space. These are Stanlight expressions encoding the
   mathematical transforms. For the mathematical description of each
   transform as a function R -> R, see Transforms.v.
   The calls to expit and exp are emitted as Evar references to those
   math functions, typed Breal -> Breal. *)
Definition unconstrained_to_constrained_fun (c: constraint) : expr -> expr :=
fun i =>
match c with
(* No constraint: the value is used as is. *)
| Cidentity => i
(* Lower bound a and upper bound b: a + (b - a) * expit i. *)
| Clower_upper a b =>
let a := Econst_float a Breal in
let b := Econst_float b Breal in
let call := Ecall (Evar $"expit" (Bfunction (Bcons Breal Bnil) Breal)) (Econs i Enil) Breal in
(Ebinop a Plus (Ebinop (Ebinop b Minus a Breal) Times call Breal) Breal)
(* Lower bound a: exp i + a. *)
| Clower a =>
let a := Econst_float a Breal in
let call := Ecall (Evar $"exp" (Bfunction (Bcons Breal Bnil) Breal)) (Econs i Enil) Breal in
(Ebinop call Plus a Breal)
(* Upper bound b: b - exp (0 - i); negation is encoded as 0 - i. *)
| Cupper b =>
let b := Econst_float b Breal in
let negi := Ebinop (Econst_float Floats.Float.zero Breal) Minus i Breal in
let call := Ecall (Evar $"exp" (Bfunction (Bcons Breal Bnil) Breal)) (Econs negi Enil) Breal in
(Ebinop b Minus call Breal)
end.
(* Remapping function for a parameter declared with variable v: the
   unconstrained-to-constrained transform selected by its declared
   constraint v.(vd_constraint). Always returns Some.
   The declared type is not consulted. Arrays are handled element by
   element, because transf_expr applies the function to each indexed
   read of the parameter. *)
Definition unconstrained_to_constrained (v: variable) : option (expr -> expr) :=
  Some (unconstrained_to_constrained_fun v.(vd_constraint)).
(* Build the parameter remapping function pmap from a list of
   (id, variable, _) triples. pmap param is
   unconstrained_to_constrained v for the first triple whose identifier
   is param, and None if param does not occur. Duplicate identifiers
   are rejected beforehand by check_nodup_params in transf_program. *)
Fixpoint u_to_c_rewrite_map {A} (parameters: list (AST.ident * variable * A)) {struct parameters}: (AST.ident -> option (expr -> expr)) :=
match parameters with
| nil => fun x => None
| (id, v, _) :: parameters =>
let inner_map := u_to_c_rewrite_map parameters in
fun param =>
if positive_eq_dec id param then (unconstrained_to_constrained v) else (inner_map param)
end.
(* Maps each constraint c to an optional function g : expr -> expr that
   computes the log-Jacobian of the corresponding transform. Given an
   expression e that computes an unconstrained parameter value, g e is
   an expression computing the logarithm of the Jacobian, i.e. of the
   derivative, of the constraint mapping for c at e. None means that no
   correction is needed. These are Stanlight expressions encoding the
   log-Jacobian. For the mathematical description as a function
   R -> R, see Transforms.v. *)
Definition change_of_variable_correction_fun (c: constraint) : option (expr -> expr) :=
match c with
(* Identity transform: derivative 1, so no correction. *)
| Cidentity => None
(* y = a + (b - a) * expit x, so dy/dx = (b - a) * expit x * (1 - expit x);
   the correction is the log of that product. *)
| Clower_upper a b =>
Some (
fun x =>
let a := Econst_float a Breal in
let b := Econst_float b Breal in
let one := Econst_float (Floats.Float.of_int Integers.Int.one) Breal in
let call := Ecall (Evar $"expit" (Bfunction (Bcons Breal Bnil) Breal)) (Econs x Enil) Breal in
let pre_log := (Ebinop (Ebinop b Minus a Breal) Times
(Ebinop call Times (Ebinop one Minus call Breal) Breal) Breal) in
Ecall (Evar $"log" (Bfunction (Bcons Breal Bnil) Breal)) (Econs pre_log Enil) Breal)
(* y = exp x + a, so dy/dx = exp x, whose logarithm is x. *)
| Clower a =>
Some (fun x => x)
(* y = b - exp (0 - x), so dy/dx = exp (0 - x), whose logarithm is 0 - x. *)
| Cupper b =>
Some (fun x => Ebinop (Econst_float Floats.Float.zero Breal) Minus x Breal)
end.
(* Insert the change of variable correction function. For arrays, this
generates an expression that is linear in the size of the array. It
would be better to generate a for loop statement for a large array
of parameters, but this would then require making
change_of_variable_correction a statement instead of an
expression. *)
(* Returns None when the constraint of v needs no correction.
   For an array parameter i of size sz, the result is a right-nested
   sum, ending in 0.0, of the correction applied to each element i[ofs]
   for every offset produced by count_up_int (Z.to_nat sz). Presumably
   these offsets are 0 .. sz - 1; confirm against count_up_int.
   For any other type, the correction is applied to i read at Breal. *)
Definition change_of_variable_correction (i: AST.ident) (v: variable): option expr :=
let typ := v.(vd_type) in
let c := v.(vd_constraint) in
let ofe := change_of_variable_correction_fun c in
match ofe with
| None => None
| Some fe =>
match typ with
(* TODO: we should probably emit loops to handle large arrays rather than unrolling like this *)
| Barray _ sz =>
Some (fold_right (fun ofs e => Ebinop (fe (Eindexed (Evar i typ)
(Econs (Econst_int ofs Bint) Enil) Breal)) Plus e Breal)
(Econst_float Floats.Float.zero Breal)
(count_up_int (Z.to_nat sz)))
| _ => Some (fe (Evar i Breal))
end
end.
(* Total change-of-variable correction for all parameters. The result
   is a right-nested chain of Plus at type Breal ending in the float
   constant converted from Int.zero. Parameters whose constraint needs
   no correction are skipped.
   NOTE(review): the zero here is written Float.of_int Int.zero, while
   the other definitions use Float.zero; kept as is because proofs may
   depend on the exact term. *)
Fixpoint collect_corrections {A} (parameters: list (AST.ident * variable * A)) {struct parameters}: expr :=
match parameters with
| nil => Econst_float (Floats.Float.of_int Integers.Int.zero) Breal
| (id,v,_) :: parameters =>
match change_of_variable_correction id v with
| None => collect_corrections parameters
| Some correction => Ebinop correction Plus (collect_corrections parameters) Breal
end
end.
(* Require that a parameter entry has no output mapping yet, i.e. its
   third component is None. transf_program installs the remapping
   function in that slot, so an existing mapping would be overwritten. *)
Definition no_param_out (p: AST.ident * basic * option (expr -> expr)) :=
let '(id, _, ofout) := p in
match ofout with
| None => (Errors.OK tt)
| Some _ =>
Errors.Error (Errors.msg "Reparameterization: parameter already has a non-empty output mapping function")
end.
(* Reject a parameter entry whose identifier equals one of the
   identifiers in StanEnv.math_idents, i.e. a parameter that would
   shadow a global math function. Returns Errors.OK tt otherwise.
   The type and output-mapping components of p are ignored. *)
Definition param_check_shadow (p: AST.ident * basic * option (expr -> expr)) :=
  match p with
  | (x, _, _) =>
    match forallb (fun m => if Pos.eq_dec m x then false else true) StanEnv.math_idents with
    | true => Errors.OK tt
    | false =>
      Errors.Error (Errors.msg "Reparameterization: parameter shadows global math functions")
    end
  end.
(* For an array parameter of size z, require -1 < z (i.e. z is
   non-negative), z < Int.modulus - 1 and z < Ptrofs.modulus - 1.
   Non-array parameters always pass. Presumably this keeps every array
   offset representable as a machine int and pointer offset - confirm.
   Note that the element type b inside Barray shadows the outer b. *)
Definition param_check_sizes (p: AST.ident * basic * option (expr -> expr)) :=
let '(id, b, fe) := p in
match b with
| Barray b z =>
match
Z_lt_ge_dec (-1) z,
Z_lt_ge_dec z (Integers.Int.modulus - 1),
Z_lt_ge_dec z (Integers.Ptrofs.modulus - 1)
with
| left _, left _, left _ => Errors.OK tt
| _, _, _ => Errors.Error (Errors.msg "Reparameterization: array size is negative or bigger than max int")
end
| _ => Errors.OK tt
end.
(* Boolean duplicate check: returns true iff no element of l occurs
   again later in l, comparing elements with the decision procedure
   decA. Each head element is looked up in the rest with in_dec. *)
Fixpoint nodupb {A} (decA: forall x y :A, {x = y} + {x <> y}) (l: list A) : bool :=
  match l with
  | nil => true
  | hd :: tl =>
    match in_dec decA hd tl with
    | left _ => false
    | right _ => nodupb decA tl
    end
  end.
(* Succeed iff the parameter identifier list l contains no duplicates,
   checked with nodupb over Pos.eq_dec. *)
Definition check_nodup_params (l: list AST.ident) : Errors.res unit :=
  if nodupb Pos.eq_dec l then
    Errors.OK tt
  else
    Errors.Error (Errors.msg "Reparameterization: duplicate parameter id").
(* Whole-program reparameterization. Steps, in order:
   1. validate every parameter entry: no output mapping yet
      (no_param_out), no clash with math function identifiers
      (param_check_shadow), no duplicate identifiers
      (check_nodup_params), array sizes in range (param_check_sizes);
   2. resolve each entry to its global variable (find_parameter);
   3. build the remapping map pmap and the summed Jacobian correction;
   4. re-declare each parameter with its declared type and, as output
      mapping, its transform into constrained space;
   5. transform every function with transf_fundef (remap parameter
      reads, append the correction) and every variable with
      transf_variable (constraint reset to Cidentity).
   pr_data_vars is carried over unchanged. *)
Definition transf_program(p: Stanlight.program): Errors.res Stanlight.program :=
do _ <- Errors.mmap (no_param_out) p.(pr_parameters_vars);
do _ <- Errors.mmap (param_check_shadow ) (p.(pr_parameters_vars));
do _ <- check_nodup_params (map (fun '(id, _, _) => id) (p.(pr_parameters_vars)));
do _ <- Errors.mmap (param_check_sizes) (p.(pr_parameters_vars));
do parameters <- Errors.mmap (find_parameter p.(pr_defs)) p.(pr_parameters_vars);
let pmap := u_to_c_rewrite_map parameters in
let correction := collect_corrections parameters in
(* The third component f of each resolved triple is dropped; the new
   output mapping is the constraint transform of the declared variable. *)
let pr_parameters_vars' := List.map (fun '(id, v, f) =>
(id, vd_type v,
Some (fun x => (unconstrained_to_constrained_fun (vd_constraint v) x))))
parameters in
do p1 <- AST.transform_partial_program2 (fun id => transf_fundef pmap correction) transf_variable p;
Errors.OK {|
Stanlight.pr_defs := AST.prog_defs p1;
Stanlight.pr_parameters_vars := pr_parameters_vars';
Stanlight.pr_data_vars := p.(pr_data_vars);
|}.