-
-
Notifications
You must be signed in to change notification settings - Fork 46
/
Typechecker.ml
1971 lines (1827 loc) · 78 KB
/
Typechecker.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(** a type/semantic checker for Stan ASTs
Functions which begin with "check_" return a typed version of their input
Functions which begin with "verify_" return unit if a check succeeds, or else
throw an Errors.SemanticError exception.
Other functions which begin with "infer"/"calculate" vary. Usually they return
a value, but a few do have error conditions.
All Errors.SemanticError exceptions are caught by check_program,
which turns the AST or exception into a Result.t for external usage.
A type environment (Env.t) is used to hold variables and functions, including
stan math functions. This is a functional map, meaning it is handled immutably.
*)
open Core
open Core.Poly
open Middle
open Ast
module Env = Environment
(* Every semantic error in this module is raised through this one function,
   so a single [Errors.SemanticError] handler can catch all of them. *)
let error e = Errors.SemanticError e |> raise

(* Warnings accumulate here, most recent first; [attach_warnings] restores
   the order in which they were emitted. *)
let warnings : Warnings.t list ref = ref []

let add_warning (span : Location_span.t) (message : string) =
  let warning = (span, message) in
  warnings := warning :: !warnings

let attach_warnings x =
  let emitted_in_order = List.rev !warnings in
  (x, emitted_in_order)

(* Name of the model being checked; identifiers may not reuse it (see
   verify_identifier). Kept as global state, which is not ideal. *)
let model_name = ref ""

(* NOTE(review): not read in this section; presumably controls whether
   functions declared without a definition are reported. *)
let check_that_all_functions_have_definition = ref true
(* Where the code being checked sits with respect to user-defined functions:
   outside any function, or inside a non-returning (void) or returning one.
   The [Fun_kind.suffix] is the enclosing function's suffix kind (FnRng,
   FnTarget, FnJacobian, FnLpdf, FnLpmf, ...), inspected by the in_*
   predicates below. For [Returning], the [UnsizedType.t] is presumably the
   declared return type. *)
type function_indicator =
  | NotInFunction
  | NonReturning of unit Fun_kind.suffix
  | Returning of unit Fun_kind.suffix * UnsizedType.t

(* Record structure holding flags and other markers about context to be
   used for error reporting. *)
type context_flags_record =
  { current_block: Env.originblock
        (* program block being checked; compared against Model, TParam and
           GQuant by the checks below *)
  ; in_toplevel_decl: bool
        (* set while checking a top-level declaration; while set, references
           to non-data variables and calls to _rng functions are rejected *)
  ; containing_function: function_indicator
  ; loop_depth: int
        (* NOTE(review): not read in this section; presumably the nesting
           depth of enclosing loops *) }
(** True while checking code inside any user-defined function body. *)
let in_function cf =
  match cf.containing_function with
  | NotInFunction -> false
  | NonReturning _ | Returning _ -> true

(** True inside a function whose suffix kind is [FnRng] (an _rng function). *)
let in_rng_function cf =
  match cf.containing_function with
  | NotInFunction -> false
  | NonReturning suffix | Returning (suffix, _) -> (
      match suffix with FnRng -> true | _ -> false )

(** True inside a function whose suffix kind is [FnTarget] (an _lp function). *)
let in_lp_function cf =
  match cf.containing_function with
  | NotInFunction -> false
  | NonReturning suffix | Returning (suffix, _) -> (
      match suffix with FnTarget -> true | _ -> false )

(** True inside a function whose suffix kind is [FnJacobian]. *)
let in_jacobian_function cf =
  match cf.containing_function with
  | NotInFunction -> false
  | NonReturning suffix | Returning (suffix, _) -> (
      match suffix with FnJacobian -> true | _ -> false )

(** True inside a user-defined _lpdf or _lpmf function. *)
let in_udf_distribution cf =
  match cf.containing_function with
  | NotInFunction -> false
  | NonReturning suffix | Returning (suffix, _) -> (
      match suffix with FnLpdf () | FnLpmf () -> true | _ -> false )

(** Initial context for top-level code of [block]: outside any function,
    not in a top-level declaration, loop depth 0. *)
let context block =
  { containing_function= NotInFunction
  ; current_block= block
  ; in_toplevel_decl= false
  ; loop_depth= 0 }
(** Autodiff level of a value of type [ut] that originates from block
    [origin]. Arrays are unwound to their element type first and tuples get a
    per-component [TupleAD]. A non-discrete value coming from Param, TParam,
    Model or Functions is autodiffable unless the current block is GQuant;
    everything else is data-only. *)
let rec calculate_autodifftype cf origin ut =
  let elem_type, _ = UnsizedType.unwind_array_type ut in
  let autodiffable_origin () =
    match origin with
    | Env.Param | Env.TParam | Env.Model | Env.Functions ->
        not
          (UnsizedType.is_discrete_type elem_type || cf.current_block = GQuant)
    | _ -> false in
  match elem_type with
  | UnsizedType.UTuple components ->
      UnsizedType.TupleAD
        (List.map components ~f:(calculate_autodifftype cf origin))
  | _ ->
      if autodiffable_origin () then UnsizedType.AutoDiffable
      else UnsizedType.DataOnly
(** (autodiff level, type) of a typed expression, the pair used for signature
    matching. *)
let arg_type x =
  let meta = x.emeta in
  (meta.ad_level, meta.type_)

let get_arg_types es = List.map es ~f:arg_type

let type_of_expr_typed ue =
  let meta = ue.emeta in
  meta.type_

let has_int_type ue =
  match ue.emeta.type_ with UnsizedType.UInt -> true | _ -> false

let has_int_array_type ue =
  match ue.emeta.type_ with
  | UnsizedType.UArray UnsizedType.UInt -> true
  | _ -> false

(** Name of the variable at the root of an lvalue, looking through tuple
    projections and indexing. *)
let rec name_of_lval lv =
  match lv.lval with
  | LVariable id -> id.name
  | LTupleProjection (inner, _) | LIndexed (inner, _) -> name_of_lval inner

let has_int_or_real_type ue =
  match ue.emeta.type_ with
  | UnsizedType.UReal -> true
  | _ -> has_int_type ue
(* -- General checks ---------------------------------------------- *)

(* The parser already rejects most keywords in current use; these extra words
   are reserved for possible future use. *)
let reserved_keywords =
  [ "generated"; "quantities"; "transformed"; "repeat"; "until"; "then"
  ; "true"; "false"; "typedef"; "struct"; "var"; "export"; "extern"
  ; "static"; "auto" ]

(** Verify that [id] is a legal identifier. Warns when it is "jacobian";
    raises a semantic error when it equals the model name, ends in a double
    underscore, or is one of [reserved_keywords]. *)
let verify_identifier id : unit =
  let name = id.name in
  let loc = id.id_loc in
  if String.equal name "jacobian" then
    add_warning loc
      "Variable name 'jacobian' will be a reserved word starting in Stan 2.38. \
       Please rename it!";
  if String.equal name !model_name then
    error (Semantic_error.ident_is_model_name loc name);
  let reserved =
    String.is_suffix name ~suffix:"__"
    || List.mem reserved_keywords name ~equal:String.equal in
  if reserved then error (Semantic_error.ident_is_keyword loc name)
(** Verify that the variable [name] being declared is not already in use.
    Variables may shadow function names (including Stan Math functions) but
    not other variables, and may not carry an unnormalized distribution
    suffix. *)
let verify_name_fresh_var loc tenv name =
  let is_existing_variable () =
    List.exists (Env.find tenv name) ~f:(function
      | {kind= `Variable _; _} -> true
      | _ -> false ) in
  if Utils.is_unnormalized_distribution name then
    error (Semantic_error.ident_has_unnormalized_suffix loc name)
  else if is_existing_variable () then
    error (Semantic_error.ident_in_use loc name)

(** Verify that the user-defined function [name] being declared is
    previously unused: it may not reuse the name of reduce_sum or another
    variadic Stan Math function, may not carry an unnormalized distribution
    suffix, and may not clash with a variable. *)
let verify_name_fresh_udf loc tenv name =
  (* Variadic functions are not in the Stan Math signature table and are
     typechecked separately, so they cannot be overloaded. *)
  let is_special_stanmath_name =
    Stan_math_signatures.is_reduce_sum_fn name
    || Stan_math_signatures.is_stan_math_variadic_function_name name in
  (* All functions are declared before data, so a clash with a variable is
     not currently possible; checked anyway for future-proofing. *)
  let clashes_with_variable () =
    List.exists
      ~f:(function {kind= `Variable _; _} -> true | _ -> false)
      (Env.find tenv name) in
  if is_special_stanmath_name then
    error (Semantic_error.ident_is_stanmath_name loc name)
  else if Utils.is_unnormalized_distribution name then
    error (Semantic_error.udf_is_unnormalized_fn loc name)
  else if clashes_with_variable () then
    error (Semantic_error.ident_in_use loc name)

(** Checks that a variable or function name does not have the _lupdf/_lupmf
    suffix and is not already in use (for now); dispatches on [is_udf]. *)
let verify_name_fresh tenv id ~is_udf =
  let verify =
    if is_udf then verify_name_fresh_udf else verify_name_fresh_var in
  verify id.id_loc tenv id.name
(** A function body whose statement return type is [srt2] is compatible with
    the declared return type [rt1] when the function is void, or when the
    body is [Complete] (presumably: it returns on every path). *)
let is_of_compatible_return_type rt1 srt2 =
  UnsizedType.(
    match (rt1, srt2) with
    | Void, _ | ReturnType _, Complete -> true
    | _, _ -> false)
(* -- Expressions ------------------------------------------------- *)

(** Typecheck [pe ? te : fe]. The condition must be an int, and the branches
    must have a common, non-function type. Each branch is promoted (when
    needed) to that type and to the least upper bound of the three autodiff
    levels. *)
let check_ternary_if loc pe te fe =
  let coerce branch target_type target_ad =
    let same_type = UnsizedType.equal branch.emeta.type_ target_type in
    let same_ad =
      UnsizedType.compare_autodifftype branch.emeta.ad_level target_ad = 0
    in
    if same_type && same_ad then branch
    else
      { expr=
          Promotion
            (branch, UnsizedType.internal_scalar target_type, target_ad)
      ; emeta= {branch.emeta with type_= target_type; ad_level= target_ad} }
  in
  let branch_type = UnsizedType.common_type (te.emeta.type_, fe.emeta.type_) in
  let ad_lub = expr_ad_lub [pe; te; fe] in
  match (pe.emeta.type_, branch_type, ad_lub) with
  | UInt, Some type_, Some ad_level when not (UnsizedType.is_fun_type type_) ->
      let then_branch = coerce te type_ ad_level in
      let else_branch = coerce fe type_ ad_level in
      mk_typed_expression
        ~expr:(TernaryIf (pe, then_branch, else_branch))
        ~ad_level ~type_ ~loc
  | _ ->
      error
        (Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_
           fe.emeta.type_ )
(** Return type of a unique signature match; [None] for ambiguous or failed
    matches. *)
let match_to_rt_option result =
  match result with
  | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt
  | _ -> None

(** Return type of Stan Math function [name] applied to [arg_tys]. Variadic
    functions take their return type from their registered signature,
    reduce_sum always returns a real, and everything else goes through
    ordinary signature matching. *)
let stan_math_return_type name arg_tys =
  let variadic_signature =
    Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures name in
  match variadic_signature with
  | Some {return_type; _} -> Some (UnsizedType.ReturnType return_type)
  | None ->
      if Stan_math_signatures.is_reduce_sum_fn name then
        Some (UnsizedType.ReturnType UReal)
      else
        match_to_rt_option
          (SignatureMismatch.matching_stanlib_function name arg_tys)
(** Resolve operator [op] applied to [arg_tys], returning the result type
    together with the promotions to apply to the arguments. Integer division
    is only defined for two ints; every other operator is resolved through
    the Stan Math functions that implement it, taking the first one that has
    a unique match. *)
let operator_stan_math_return_type op arg_tys =
  let unique_match name =
    match SignatureMismatch.matching_stanlib_function name arg_tys with
    | SignatureMismatch.UniqueMatch (rt, _, promotions) -> Some (rt, promotions)
    | _ -> None in
  match (op, arg_tys) with
  | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] ->
      Some (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion])
  | IntDivide, _ -> None
  | _ ->
      let resolved =
        List.filter_map
          (Stan_math_signatures.operator_to_stan_math_fns op)
          ~f:unique_match in
      List.hd resolved
(** Typecheck the compound assignment [lhs assop= rhs], where [arg_tys] is
    [[lhs; rhs]]. Returns [Some Void] when the operator's result type equals
    the left-hand side's type (so it can be stored back), and [None]
    otherwise. Elementwise [.*=] / [./=] with a scalar result are rejected.
    Only Divide, Plus, Minus, Times, EltTimes and EltDivide are supported. *)
let assignmentoperator_stan_math_return_type assop arg_tys =
  let operator_result =
    match assop with
    | Operator.Divide ->
        match_to_rt_option
          (SignatureMismatch.matching_stanlib_function "divide" arg_tys)
    | Plus | Minus | Times | EltTimes | EltDivide ->
        Option.map (operator_stan_math_return_type assop arg_tys) ~f:fst
    | _ -> None in
  match operator_result with
  | Some (ReturnType rtype) ->
      let lhs_type = snd (List.hd_exn arg_tys) in
      let is_elementwise =
        assop = Operator.EltTimes || assop = Operator.EltDivide in
      if
        rtype = lhs_type
        && not (is_elementwise && UnsizedType.is_scalar_type rtype)
      then Some UnsizedType.Void
      else None
  | _ -> None
(** Typecheck [le op re] against the Stan Math signatures implementing [op],
    applying the argument promotions required by the matching signature. *)
let check_binop loc op le re =
  let signature = operator_stan_math_return_type op (get_arg_types [le; re]) in
  let ad_lub = expr_ad_lub [le; re] in
  match (signature, ad_lub) with
  | Some (ReturnType type_, [promote_lhs; promote_rhs]), Some ad_level ->
      let lhs = Promotion.promote le promote_lhs in
      let rhs = Promotion.promote re promote_rhs in
      mk_typed_expression ~expr:(BinOp (lhs, op, rhs)) ~ad_level ~type_ ~loc
  | _ ->
      error
        (Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_)

(** Typecheck a prefix operator application; the result keeps the operand's
    autodiff level. *)
let check_prefixop loc op te =
  match operator_stan_math_return_type op [arg_type te] with
  | Some (ReturnType type_, _) ->
      mk_typed_expression ~loc ~type_ ~ad_level:te.emeta.ad_level
        ~expr:(PrefixOp (op, te))
  | _ -> error (Semantic_error.illtyped_prefix_op loc op te.emeta.type_)

(** Typecheck a postfix operator application; the result keeps the operand's
    autodiff level. *)
let check_postfixop loc op te =
  match operator_stan_math_return_type op [arg_type te] with
  | Some (ReturnType type_, _) ->
      mk_typed_expression ~loc ~type_ ~ad_level:te.emeta.ad_level
        ~expr:(PostfixOp (te, op))
  | _ -> error (Semantic_error.illtyped_postfix_op loc op te.emeta.type_)
(** Resolve identifier [id] and return its (autodiff level, type).
    The environment is searched under [Utils.stdlib_distribution_name id.name],
    and only the first entry found is inspected. The cases are tried in
    order:
    - unknown name: error, suggesting the nearest known identifier;
    - Stan Math library function: typed as [UMathLibraryFunction];
    - a Param/TParam/GQuant variable while checking a top-level declaration:
      error (non-data variable in a size declaration);
    - an unnormalized distribution name outside the model block or an
      _lpdf/_lpmf/_lp function: error;
    - variable: its declared type, with autodiff level from its origin block;
    - user-defined _lpdf/_lpmf function: its type, with the suffix recomputed
      from the name actually written;
    - any other user-defined function: its declared type. *)
let check_id cf loc tenv id =
  match Env.find tenv (Utils.stdlib_distribution_name id.name) with
  | [] ->
      Semantic_error.ident_not_in_scope loc id.name
        (Env.nearest_ident tenv id.name)
      |> error
  | {kind= `StanMath; _} :: _ ->
      ( calculate_autodifftype cf MathLibrary UMathLibraryFunction
      , UnsizedType.UMathLibraryFunction )
  | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _
    when cf.in_toplevel_decl ->
      Semantic_error.non_data_variable_size_decl loc |> error
  | _ :: _
    when Utils.is_unnormalized_distribution id.name
         && not
              ((in_udf_distribution cf || in_lp_function cf)
              || cf.current_block = Model) ->
      Semantic_error.invalid_unnormalized_fn loc |> error
  | {kind= `Variable {origin; _}; type_} :: _ ->
      (calculate_autodifftype cf origin type_, type_)
  | { kind= `UserDefined | `UserDeclared _
    ; type_= UFun (args, rt, (FnLpdf _ | FnLpmf _), mem_pattern) }
    :: _ ->
      (* the suffix stored in the environment is replaced by the one derived
         from the name as written at this use site *)
      let type_ =
        UnsizedType.UFun
          (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in
      (calculate_autodifftype cf Functions type_, type_)
  | {kind= `UserDefined | `UserDeclared _; type_} :: _ ->
      (calculate_autodifftype cf Functions type_, type_)
(** Typecheck a reference to the name [id] (a variable or a function used as
    a value); its type and autodiff level come from [check_id]. *)
let check_variable cf loc tenv id =
  let ad_level, type_ = check_id cf loc tenv id in
  let expr = Variable id in
  mk_typed_expression ~loc ~expr ~type_ ~ad_level
(** Unify the types of [es], starting from [type_].
    Returns [Ok (ad_level, common_type, promotions)], with one promotion per
    element, or [Error (type_so_far, meta)] for the first element whose type
    has no common type with the ones before it. *)
let get_consistent_types type_ es =
  let rec unify_from acc = function
    | [] -> Ok acc
    | e :: rest -> (
        match UnsizedType.common_type (acc, e.emeta.type_) with
        | Some widened -> unify_from widened rest
        | None -> Error (acc, e.emeta) ) in
  match unify_from type_ es with
  | Error failure -> Error failure
  | Ok common ->
      (* correctness: in the Ok case every tuple type has the same length,
         so expr_ad_lub cannot fail *)
      let ad = Option.value_exn (expr_ad_lub es) in
      let promotions =
        List.map (get_arg_types es)
          ~f:(Promotion.get_type_promotion_exn (ad, common)) in
      Ok (ad, common, promotions)
(** Typecheck an array literal [{e1, ..., en}]. All elements are unified to a
    common type, promoted to it, and the literal is an array of that type. *)
let check_array_expr loc es =
  match es with
  | [] ->
      (* NB: This is actually disallowed by parser *)
      error (Semantic_error.empty_array loc)
  | first :: _ -> (
      match get_consistent_types first.emeta.type_ es with
      | Ok (ad_level, elem_type, promotions) ->
          let promoted = Promotion.promote_list es promotions in
          mk_typed_expression ~expr:(ArrayExpr promoted) ~ad_level
            ~type_:(UnsizedType.UArray elem_type) ~loc
      | Error (expected, meta) ->
          error
            (Semantic_error.mismatched_array_types meta.loc expected
               meta.type_ ) )
(** Typecheck a row-vector literal [[e1, ..., en]]; elements are promoted to
    their common type.
    - First element a real row vector: unify starting from [URowVector]; the
      literal is a matrix (complex if the common type is a complex row
      vector).
    - First element a complex row vector: unify starting from
      [UComplexRowVector]; the literal is a complex matrix.
    - Otherwise: unify as scalars starting from [UReal]; the literal is a row
      vector (complex if the common type is [UComplex]). *)
let check_rowvector loc es =
  match es with
  | {emeta= {type_= UnsizedType.URowVector; _}; _} :: _ -> (
      match get_consistent_types URowVector es with
      | Ok (ad_level, typ, promotions) ->
          mk_typed_expression
            ~expr:(RowVectorExpr (Promotion.promote_list es promotions))
            ~ad_level
            ~type_:(if typ = UComplexRowVector then UComplexMatrix else UMatrix)
            ~loc
      | Error (_, meta) ->
          (* an element could not be unified with the rows seen so far *)
          Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error)
  | {emeta= {type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> (
      match get_consistent_types UComplexRowVector es with
      | Ok (ad_level, _, promotions) ->
          mk_typed_expression
            ~expr:(RowVectorExpr (Promotion.promote_list es promotions))
            ~ad_level ~type_:UComplexMatrix ~loc
      | Error (_, meta) ->
          Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error)
  | _ -> (
      match get_consistent_types UReal es with
      | Ok (ad_level, typ, promotions) ->
          mk_typed_expression
            ~expr:(RowVectorExpr (Promotion.promote_list es promotions))
            ~ad_level
            ~type_:(if typ = UComplex then UComplexRowVector else URowVector)
            ~loc
      | Error (_, meta) ->
          (* an element is not a (real or complex) scalar *)
          Semantic_error.invalid_row_vector_types meta.loc meta.type_ |> error)
(* index checking *)

(** Classify an index: a single int selects one element (`Single); any other
    index (int array, range or [All]) selects several (`Multi). *)
let indexing_type = function
  | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single
  | _ -> `Multi

let is_multiindex i =
  match indexing_type i with `Multi -> true | `Single -> false
(** Type of a value of type [ut] indexed by [indices].
    Each index is classified as `Single or `Multi (see [indexing_type]).
    - An array consumes one index per dimension and keeps that dimension only
      for a `Multi index.
    - A (row) vector takes one index.
    - A matrix takes up to two indices; single indices drop dimensions, so
      m[i] is a row vector and m[:, j] is a vector.
    Complex containers map to their complex counterparts. Indexing a
    non-container, or using too many indices, raises [not_indexable]
    (reporting the original type [ut] and the total index count). *)
let inferred_unsizedtype_of_indexed ~loc ut indices =
  let rec aux type_ idcs =
    (* element types produced by dropping dimensions of [type_] *)
    let vec, rowvec, scalar =
      if UnsizedType.is_complex_type type_ then
        UnsizedType.(UComplexVector, UComplexRowVector, UComplex)
      else (UVector, URowVector, UReal) in
    match (type_, idcs) with
    | _, [] -> type_
    | UnsizedType.UArray type_, `Single :: tl -> aux type_ tl
    | UArray type_, `Multi :: tl -> aux type_ tl |> UnsizedType.UArray
    | (UVector | URowVector | UComplexRowVector | UComplexVector), [`Single]
     |(UMatrix | UComplexMatrix), [`Single; `Single] ->
        scalar
    | ( ( UVector | URowVector | UMatrix | UComplexVector | UComplexMatrix
        | UComplexRowVector )
      , [`Multi] )
     |(UMatrix | UComplexMatrix), [`Multi; `Multi] ->
        type_
    | (UMatrix | UComplexMatrix), ([`Single] | [`Single; `Multi]) -> rowvec
    | (UMatrix | UComplexMatrix), [`Multi; `Single] -> vec
    (* too many indices, or indexing something that is not a container *)
    | (UMatrix | UComplexMatrix), _ :: _ :: _ :: _
     |(UVector | URowVector | UComplexRowVector | UComplexVector), _ :: _ :: _
     |( (UInt | UReal | UComplex | UFun _ | UMathLibraryFunction | UTuple _)
      , _ :: _ ) ->
        Semantic_error.not_indexable loc ut (List.length indices) |> error in
  aux ut (List.map ~f:indexing_type indices)
(** Autodiff level of an indexed expression. It is the least upper bound of
    the indexed value's level [at] and the levels of all index expressions,
    spread over the shape of the result type [ut]. *)
let inferred_ad_type_of_indexed at ut uindices =
  (* correctness: index expressions only contain int types,
     so lub_ad_type should never be [None]. *)
  let index_ad_level = function
    | All -> UnsizedType.DataOnly
    | Single e | Upfrom e | Downfrom e -> e.emeta.ad_level
    | Between (lower, upper) ->
        UnsizedType.lub_ad_type [lower.emeta.ad_level; upper.emeta.ad_level]
        |> Option.value_exn in
  let combined =
    UnsizedType.lub_ad_type (at :: List.map uindices ~f:index_ad_level)
    |> Option.value_exn in
  UnsizedType.fill_adtype_for_type combined ut
(* function checking *)

(** A function applied with conditional notation [f(y | theta)] must end in
    one of [Utils.conditioning_suffices]. *)
let verify_conddist_name loc id =
  let has_suffix suffix = String.is_suffix id.name ~suffix in
  if not (List.exists Utils.conditioning_suffices ~f:has_suffix) then
    error (Semantic_error.conditional_notation_not_allowed loc)

(** Conversely, a function ending in one of [Utils.conditioning_suffices]
    must be applied with conditional notation. *)
let verify_fn_conditioning loc id =
  let has_suffix suffix = String.is_suffix id.name ~suffix in
  if List.exists Utils.conditioning_suffices ~f:has_suffix then
    error (Semantic_error.conditioning_required loc)
(** Calls to _lp functions (which modify target) are allowed in the model
    block and inside other _lp functions. In transformed parameters they are
    accepted with a deprecation warning; anywhere else they are an error. *)
let verify_fn_target_plus_equals cf loc id =
  let is_lp_call = String.is_suffix id.name ~suffix:"_lp" in
  let allowed_context = in_lp_function cf || cf.current_block = Model in
  if is_lp_call && cf.current_block = TParam then
    add_warning loc
      "Using _lp functions in transformed parameters is deprecated and will \
       be disallowed in Stan 2.39. Use an _jacobian function instead, as \
       this allows change of variable adjustments which are conditionally \
       enabled by the algorithms."
  else if is_lp_call && not allowed_context then
    error (Semantic_error.target_plusequals_outside_model_or_logprob loc)

(** Calls to _jacobian functions are only allowed in transformed parameters
    or inside other _jacobian functions. *)
let verify_fn_jacobian_plus_equals cf loc id =
  let is_jacobian_call = String.is_suffix id.name ~suffix:"_jacobian" in
  let allowed_context =
    in_jacobian_function cf || cf.current_block = TParam in
  if is_jacobian_call && not allowed_context then
    error (Semantic_error.jacobian_plusequals_not_allowed loc)
(** _rng functions may not be used in top-level declarations, in the
    transformed parameters or model blocks, or inside user-defined functions
    that are not themselves _rng functions. *)
let verify_fn_rng cf loc id =
  if String.is_suffix id.name ~suffix:"_rng" then (
    if cf.in_toplevel_decl then error (Semantic_error.invalid_decl_rng_fn loc)
    else
      let in_non_rng_function = in_function cf && not (in_rng_function cf) in
      let in_forbidden_block =
        cf.current_block = TParam || cf.current_block = Model in
      if in_non_rng_function || in_forbidden_block then
        error (Semantic_error.invalid_rng_fn loc) )
(** Unnormalized distribution functions (the _lupdf/_lupmf spellings) can
    only be used in the model block or inside user-defined _lpdf, _lpmf or
    _lp functions. This is the same rule that [check_id] applies to bare
    references to such names; previously _lp functions were wrongly rejected
    here. *)
let verify_unnormalized cf loc id =
  if
    Utils.is_unnormalized_distribution id.name
    && not
         ( in_udf_distribution cf || in_lp_function cf
         || cf.current_block = Model )
  then Semantic_error.invalid_unnormalized_fn loc |> error
(** Build a function-application node: [CondDistApp] when written with
    conditional notation, [FunApp] otherwise. A discrete result type is
    DataOnly; otherwise the node is AutoDiffable exactly when some argument
    is. That level is then spread over the shape of [type_]. *)
let mk_fun_app ~is_cond_dist ~loc kind name args ~type_ : Ast.typed_expression =
  let node =
    if is_cond_dist then CondDistApp (kind, name, args)
    else FunApp (kind, name, args) in
  let arg_levels = List.map args ~f:(fun arg -> arg.emeta.ad_level) in
  let scalar_ad =
    if
      (not (UnsizedType.is_discrete_type type_))
      && UnsizedType.any_autodiff arg_levels
    then UnsizedType.AutoDiffable
    else UnsizedType.DataOnly in
  mk_typed_expression ~expr:node ~loc ~type_
    ~ad_level:(UnsizedType.fill_adtype_for_type scalar_ad type_)
(** Typecheck an application of an ordinary (not variadic, not reduce_sum)
    returning function [id] to the typed arguments [es]. The environment is
    searched under [Utils.normalized_name id.name].
    - A variable under that name is an error, unless the name is also a
      Stan Math function (variables may shadow library functions).
    - An unknown name yields the most specific error available:
      - an _lpmf/_lupmf call where only the _lpdf exists (or vice versa);
      - a known distribution family used with one of
        [Utils.cumulative_distribution_suffices_w_rng] that is not declared;
      - otherwise a plain "undeclared identifier" with a nearest-name hint.
    - Otherwise the arguments are matched against the function's signatures.
      A void function, ambiguous promotions and signature mismatches are
      errors; a unique match builds the (possibly conditional-notation)
      application with the required argument promotions. *)
let check_normal_fn ~is_cond_dist loc tenv id es =
  match Env.find tenv (Utils.normalized_name id.name) with
  | {kind= `Variable _; _} :: _
  (* variables can sometimes shadow stanlib functions, so we have to check this *)
    when not
           (Stan_math_signatures.is_stan_math_function_name
              (Utils.normalized_name id.name)) ->
      Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error
  | [] ->
      (match Utils.split_distribution_suffix id.name with
      | Some (prefix, suffix) -> (
          (* names of all Stan Math distribution families *)
          let known_families =
            List.map
              ~f:(fun (_, y, _, _) -> y)
              Stan_math_signatures.distributions in
          let is_known_family s =
            List.mem known_families s ~equal:String.equal in
          match suffix with
          (* e.g. foo_lpmf called but only foo_lpdf is defined *)
          | ("lpmf" | "lupmf") when Env.mem tenv (prefix ^ "_lpdf") ->
              Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
                (prefix, suffix)
          | ("lpdf" | "lupdf") when Env.mem tenv (prefix ^ "_lpmf") ->
              Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
                (prefix, suffix)
          | _ ->
              if
                is_known_family prefix
                && List.mem ~equal:String.equal
                     Utils.cumulative_distribution_suffices_w_rng suffix
              then
                Semantic_error
                .returning_fn_expected_undeclared_dist_suffix_found loc
                  (prefix, suffix)
              else
                Semantic_error.returning_fn_expected_undeclaredident_found loc
                  id.name
                  (Env.nearest_ident tenv id.name))
      | None ->
          Semantic_error.returning_fn_expected_undeclaredident_found loc id.name
            (Env.nearest_ident tenv id.name))
      |> error
  | _ (* a function *) -> (
      (* NB: At present, [SignatureMismatch.matching_function] cannot handle overloaded function types.
         This is not needed until UDFs can be higher-order, as it is special cased for
         variadic functions
      *)
      match
        SignatureMismatch.matching_function tenv id.name (get_arg_types es)
      with
      | UniqueMatch (Void, _, _) ->
          (* a void function cannot be used as an expression *)
          Semantic_error.returning_fn_expected_nonreturning_found loc id.name
          |> error
      | UniqueMatch (ReturnType ut, fnk, promotions) ->
          (* [fnk] builds the function kind from the suffix of the name as
             written *)
          mk_fun_app ~is_cond_dist ~loc
            (fnk (Fun_kind.suffix_from_name id.name))
            id
            (Promotion.promote_list es promotions)
            ~type_:ut
      | AmbiguousMatch sigs ->
          Semantic_error.ambiguous_function_promotion loc id.name
            (Some (List.map ~f:type_of_expr_typed es))
            sigs
          |> error
      | SignatureErrors (l, b) ->
          es
          |> List.map ~f:(fun e -> e.emeta.type_)
          |> Semantic_error.illtyped_fn_app loc id.name (l, b)
          |> error)
(** Given a constraint function [matches], check every environment entry for
    [fname] (looked up under [Utils.stdlib_distribution_name]) and return one
    of:
    - [UniqueMatch], for the match chosen among the successful ones by
      [SignatureMismatch.unique_minimum_promotion];
    - [AmbiguousMatch], listing the tied function signatures as
      (return type, arguments);
    - [SignatureErrors], carrying the error of the first failing candidate,
      when nothing matches.
    NOTE(review): [List.hd_exn] presumably raises if the name has no
    environment entries at all. *)
let find_matching_first_order_fn tenv matches fname =
  let results =
    List.map ~f:matches
      (Env.find tenv (Utils.stdlib_distribution_name fname.name)) in
  let successes, failures = List.partition_map results ~f:Result.to_either in
  match SignatureMismatch.unique_minimum_promotion successes with
  | Ok best -> SignatureMismatch.UniqueMatch best
  | Error (Some tied) ->
      let signatures =
        List.filter_map tied ~f:(function
          | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args)
          | _ -> None ) in
      SignatureMismatch.AmbiguousMatch signatures
  | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn failures)
(** Build the typed [Variable] expression for function [id], given its
    resolved type [ftype] (e.g. the function argument of a higher-order
    call). For _lpdf/_lpmf functions the suffix is recomputed from the name
    as written, using [Fun_kind.suffix_from_name]. A non-function type is an
    internal compiler error. *)
let make_function_variable cf loc id ftype =
  let to_variable type_ =
    mk_typed_expression ~expr:(Variable id)
      ~ad_level:(calculate_autodifftype cf Functions type_)
      ~type_ ~loc in
  match ftype with
  | UnsizedType.UFun (args, rt, (FnLpdf _ | FnLpmf _), mem_pattern) ->
      to_variable
        (UnsizedType.UFun
           (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) )
  | UnsizedType.UFun _ -> to_variable ftype
  | type_ ->
      Common.ICE.internal_compiler_error
        [%message
          "Attempting to create function variable out of "
            (type_ : UnsizedType.t)]
(** Dispatch a function application to the right checker. The generic
    variadic higher-order functions and reduce_sum have bespoke
    typechecking; everything else goes through ordinary signature matching. *)
let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list)
    =
  let name = id.name in
  match
    ( Stan_math_signatures.is_stan_math_variadic_function_name name
    , Stan_math_signatures.is_reduce_sum_fn name )
  with
  | true, _ -> check_variadic ~is_cond_dist loc cf tenv id tes
  | false, true -> check_reduce_sum ~is_cond_dist loc cf tenv id tes
  | false, false -> check_normal_fn ~is_cond_dist loc tenv id tes
(** Reduce sum is a special case, even compared to the other
    variadic functions, because it is polymorphic in the type of the
    first argument. The first, fourth, and fifth arguments must agree,
    which is too complicated to be captured declaratively.

    [tes] are the typed call arguments: the reducer function, the array to
    slice, an int (presumably the grainsize), then any shared arguments. The
    sliced array's element type must be one of
    [Stan_math_signatures.reduce_sum_slice_types]. The call has type real. *)
and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
  (* Fallback check against a generic reducer shape (slices of reals), used
     when no reducer-specific slice type is available; its error is what
     gets reported to the user. *)
  let basic_mismatch () =
    let mandatory_args =
      UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in
    let mandatory_fun_args =
      UnsizedType.
        [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in
    SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
      mandatory_fun_args UReal (get_arg_types tes) in
  (* Check one environment entry [fn] for the reducer: its first parameter
     type fixes both the slice argument of the call and the reducer's first
     parameter, followed by two data-only ints. Non-function entries fall
     back to [basic_mismatch]. *)
  let matching remaining_es fn =
    match fn with
    | Env.{type_= UnsizedType.UFun (sliced_arg_fun :: _, _, _, _) as ftype; _}
      ->
        let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in
        let mandatory_fun_args =
          [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in
        let arg_types =
          (calculate_autodifftype cf Functions ftype, ftype)
          :: get_arg_types remaining_es in
        SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
          mandatory_fun_args UReal arg_types
    | _ -> basic_mismatch () in
  match tes with
  | {expr= Variable fname; _}
    :: ({emeta= {type_= slice_type; _}; _} :: _ as remaining_es) -> (
      (* the sliced argument must be an array of a supported element type *)
      let slice_type, n = UnsizedType.unwind_array_type slice_type in
      if n = 0 then
        Semantic_error.illtyped_reduce_sum_not_array loc slice_type |> error
      else if
        not
        @@ List.mem Stan_math_signatures.reduce_sum_slice_types slice_type
             ~equal:( = )
      then Semantic_error.illtyped_reduce_sum_slice loc slice_type |> error;
      match find_matching_first_order_fn tenv (matching remaining_es) fname with
      | SignatureMismatch.UniqueMatch (ftype, promotions) ->
          (* a valid signature exists *)
          let tes = make_function_variable cf loc fname ftype :: remaining_es in
          mk_fun_app ~is_cond_dist ~loc (StanLib FnPlain) id
            (Promotion.promote_list tes promotions)
            ~type_:UnsizedType.UReal
      | AmbiguousMatch ps ->
          Semantic_error.ambiguous_function_promotion loc fname.name None ps
          |> error
      | SignatureErrors (expected_args, err) ->
          Semantic_error.illtyped_reduce_sum loc id.name
            (List.map ~f:type_of_expr_typed tes)
            expected_args err
          |> error)
  | _ ->
      (* first argument is not a function name, or there is no slice
         argument. NOTE(review): relies on basic_mismatch returning Error
         here; Option.value_exn raises otherwise. *)
      let expected_args, err =
        basic_mismatch () |> Result.error |> Option.value_exn in
      Semantic_error.illtyped_reduce_sum loc id.name
        (List.map ~f:type_of_expr_typed tes)
        expected_args err
      |> error
(** Typecheck a call to one of the generic variadic higher-order Stan Math
    functions. Its entry in [stan_math_variadic_signatures] gives the control
    arguments, the required leading parameters and return type of the
    function argument, and the call's return type. The first call argument
    must name a function, which is resolved with
    [find_matching_first_order_fn] against the remaining arguments.
    [Hashtbl.find_exn] relies on [check_fn] only dispatching here for names
    satisfying [is_stan_math_variadic_function_name]. *)
and check_variadic ~is_cond_dist loc cf tenv id tes =
  let Stan_math_signatures.
        {control_args; required_fn_args; required_fn_rt; return_type} =
    Hashtbl.find_exn Stan_math_signatures.stan_math_variadic_signatures id.name
  in
  (* Check one environment entry for the function argument against the
     remaining call arguments. *)
  let matching remaining_es Env.{type_= ftype; _} =
    let arg_types =
      (calculate_autodifftype cf Functions ftype, ftype)
      :: get_arg_types remaining_es in
    SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args
      required_fn_args required_fn_rt arg_types in
  match tes with
  | {expr= Variable fname; _} :: remaining_es -> (
      match find_matching_first_order_fn tenv (matching remaining_es) fname with
      | SignatureMismatch.UniqueMatch (ftype, promotions) ->
          let tes = make_function_variable cf loc fname ftype :: remaining_es in
          mk_fun_app ~is_cond_dist ~loc (StanLib FnPlain) id
            (Promotion.promote_list tes promotions)
            ~type_:return_type
      | AmbiguousMatch ps ->
          Semantic_error.ambiguous_function_promotion loc fname.name None ps
          |> error
      | SignatureErrors (expected_args, err) ->
          Semantic_error.illtyped_variadic loc id.name
            (List.map ~f:type_of_expr_typed tes)
            expected_args required_fn_rt err
          |> error)
  | _ ->
      (* first argument is not a function name. NOTE(review): assumes
         check_variadic_args fails for these arguments; Option.value_exn
         raises otherwise. *)
      let expected_args, err =
        SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args
          required_fn_args required_fn_rt (get_arg_types tes)
        |> Result.error |> Option.value_exn in
      Semantic_error.illtyped_variadic loc id.name
        (List.map ~f:type_of_expr_typed tes)
        expected_args required_fn_rt err
      |> error
(** Typecheck a function application, then run the name-based context
    checks: identifier validity, conditional notation, and placement of
    _lp, _jacobian, _rng and unnormalized-density calls. Signature errors
    from [check_fn] are reported first, because it runs first. *)
and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) =
  let typed_app = check_fn ~is_cond_dist loc cf tenv id es in
  verify_identifier id;
  if is_cond_dist then verify_conddist_name loc id
  else verify_fn_conditioning loc id;
  verify_fn_target_plus_equals cf loc id;
  verify_fn_jacobian_plus_equals cf loc id;
  verify_fn_rng cf loc id;
  verify_unnormalized cf loc id;
  typed_app
(** Typecheck [e[indices]]. The indices are checked before the indexed
    expression itself; the result type and autodiff level follow from both. *)
and check_indexed loc cf tenv e indices =
  let typed_indices = List.map indices ~f:(check_index cf tenv) in
  let base = check_expression cf tenv e in
  let type_ =
    inferred_unsizedtype_of_indexed ~loc base.emeta.type_ typed_indices in
  let ad_level =
    inferred_ad_type_of_indexed base.emeta.ad_level type_ typed_indices in
  mk_typed_expression
    ~expr:(Indexed (base, typed_indices))
    ~ad_level ~type_ ~loc
(** Typecheck a single index. A single index must be an int or an int array;
    range bounds must be ints. *)
and check_index cf tenv index =
  let range_bound e = check_expression_of_int_type cf tenv e "Range bound" in
  match index with
  | All -> All
  | Single e ->
      let te = check_expression cf tenv e in
      if not (has_int_type te || has_int_array_type te) then
        error
          (Semantic_error.int_intarray_or_range_expected te.emeta.loc
             te.emeta.type_ )
      else Single te
  | Upfrom e -> Upfrom (range_bound e)
  | Downfrom e -> Downfrom (range_bound e)
  | Between (lo, hi) ->
      let lo = range_bound lo in
      let hi = range_bound hi in
      Between (lo, hi)
and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) :
Ast.typed_expression =
let loc = emeta.loc in
let ce = check_expression cf tenv in
match expr with
| TernaryIf (e1, e2, e3) ->
let pe = ce e1 in
let te = ce e2 in
let fe = ce e3 in
check_ternary_if loc pe te fe
| BinOp (e1, op, e2) ->
let le = ce e1 in
let re = ce e2 in
let binop_type_warnings x y =
  (* Emit warnings for binary expressions that typecheck but are likely
     mistakes. [x] and [y] are the typed operands of [op]. *)
  match (x.emeta.type_, y.emeta.type_, op) with
  | UInt, UInt, Divide ->
      (* int / int truncates; suggest a real-valued rewrite *)
      let hint ppf () =
        match (x.expr, y.expr) with
        | IntNumeral x, _ ->
            Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_typed_expression y
        | _, Ast.IntNumeral y ->
            Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_typed_expression x y
        | _ ->
            Fmt.pf ppf "%a * 1.0 / %a" Pretty_printing.pp_typed_expression x
              Pretty_printing.pp_typed_expression y in
      let s =
        Fmt.str
          "@[<v>@[<hov 0>Found int division:@]@ @[<hov 2>%a@]@,\
           @[<hov>%a@]@ @[<hov 2>%a@]@,\
           @[<hov>%a@]@]"
          Pretty_printing.pp_expression {expr; emeta} Fmt.text
          "Values will be rounded towards zero. If rounding is not \
           desired you can write the division as"
          hint () Fmt.text
          "If rounding is intended please use the integer division \
           operator %/%." in
      add_warning x.emeta.loc s
  | (UArray UMatrix | UMatrix), (UInt | UReal), Pow ->
      (* matrix ^ scalar is elementwise, which may not be what was meant *)
      let s =
        Fmt.str
          "@[<v>@[<hov 0>Found matrix^scalar:@]@ @[<hov 2>%a@]@,\
           @[<hov>%a@]@ @[<hov>%a@]@]" Pretty_printing.pp_expression
          {expr; emeta} Fmt.text
          "matrix ^ number is interpreted as element-wise \
           exponentiation. If this is intended, you can silence this \
           warning by using elementwise operator .^"
          Fmt.text
          "If you intended matrix exponentiation, use the function \
           matrix_power(matrix,int) instead." in
      add_warning x.emeta.loc s
  | _ when Operator.is_cmp op -> (
      match le.expr with
      | BinOp (e1, op2, e2) when Operator.is_cmp op2 ->
          (* The expression parses as [(e1 op2 e2) op re]: [op2] is the inner
             comparison and [op] the outer one. Each operator must be printed
             in its own position, otherwise the message is wrong whenever
             [op] and [op2] differ (e.g. [a < b == c]). *)
          let pp_e = Pretty_printing.pp_typed_expression in
          let pp = Operator.pp in
          add_warning loc
            (Fmt.str
               "Found %a. This is interpreted as %a. Consider if the \
                intended meaning was %a instead.@ You can silence this \
                warning by adding explicit parenthesis. This can be \
                automatically changed using the canonicalize flag for \
                stanc"
               (fun ppf () ->
                 Fmt.pf ppf "@[<hov>%a %a %a@]" pp_e le pp op pp_e re )
               ()
               (fun ppf () ->
                 Fmt.pf ppf "@[<hov>(%a) %a %a@]" pp_e le pp op pp_e re )
               ()
               (fun ppf () ->
                 Fmt.pf ppf "@[<hov>%a %a %a && %a %a %a@]" pp_e e1 pp op2
                   pp_e e2 pp_e e2 pp op pp_e re )
               () )
      | _ -> () )
  | _ -> () in
binop_type_warnings le re;
check_binop loc op le re
| PrefixOp (op, e) -> ce e |> check_prefixop loc op
| PostfixOp (e, op) -> ce e |> check_postfixop loc op
| Variable id ->
verify_identifier id;
check_variable cf loc tenv id
| IntNumeral s -> (
match float_of_string_opt s with
| Some i when i < 2_147_483_648.0 ->
mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly
~type_:UInt ~loc
| _ -> Semantic_error.bad_int_literal loc |> error)
| RealNumeral s ->
mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly ~type_:UReal
~loc
| ImagNumeral s ->
mk_typed_expression ~expr:(ImagNumeral s) ~ad_level:DataOnly
~type_:UComplex ~loc
| GetTarget ->
(* Target+= can only be used in model and functions with right suffix (same for tilde etc) *)
if
not
(in_lp_function cf || cf.current_block = Model
|| cf.current_block = TParam)
then
Semantic_error.target_plusequals_outside_model_or_logprob loc |> error
else
mk_typed_expression ~expr:GetTarget
~ad_level:(calculate_autodifftype cf cf.current_block UReal)
~type_:UReal ~loc
| ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc
| RowVectorExpr es -> es |> List.map ~f:ce |> check_rowvector loc
| Paren e ->
let te = ce e in
mk_typed_expression ~expr:(Paren te) ~ad_level:te.emeta.ad_level
~type_:te.emeta.type_ ~loc
| Indexed (e, indices) -> check_indexed loc cf tenv e indices
| TupleProjection (e, i) -> (
let te = ce e in
match (te.emeta.type_, te.emeta.ad_level) with
| UTuple ts, TupleAD ads -> (
match (List.nth ts (i - 1), List.nth ads (i - 1)) with
| Some t, Some ad ->
mk_typed_expression
~expr:(TupleProjection (te, i))
~ad_level:ad ~type_:t ~loc:emeta.loc
| None, None ->
Semantic_error.tuple_index_invalid_index emeta.loc
(List.length ts) i
|> error
| _ ->
Common.ICE.internal_compiler_error
[%message
"Error in internal representation: tuple types don't match AD"]
)
| UTuple _, ad ->
Common.ICE.internal_compiler_error
[%message
"Error in internal representation: tuple doesn't have tupleAD"
(ad : UnsizedType.autodifftype)]
| _, _ ->
Semantic_error.tuple_index_not_tuple emeta.loc te.emeta.type_ |> error
)
| TupleExpr es ->
let tes = List.map ~f:ce es in
if List.is_empty tes then Semantic_error.empty_tuple emeta.loc |> error
else
mk_typed_expression ~expr:(TupleExpr tes)
~ad_level:(TupleAD (List.map ~f:(fun e -> e.emeta.ad_level) tes))
~type_:(UTuple (List.map ~f:(fun e -> e.emeta.type_) tes))
~loc:emeta.loc
| FunApp ((), id, es) ->
es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id
| CondDistApp ((), id, es) ->
es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id
| Promotion (e, _, _) ->
(* Should never happen: promotions are produced during typechecking *)
Common.ICE.internal_compiler_error
[%message "Promotion in untyped AST" (e : Ast.untyped_expression)]
(* Typecheck [e] and require the result to have int type. [name] describes
   the expression's role in the error message (e.g. "Range bound"). *)
and check_expression_of_int_type cf tenv e name =
  let typed = check_expression cf tenv e in
  match has_int_type typed with
  | true -> typed
  | false ->
      Semantic_error.int_expected typed.emeta.loc name typed.emeta.type_
      |> error
(* Typecheck [e] and require the result to have int or real type; otherwise
   report an error that mentions [name]. *)
let check_expression_of_int_or_real_type cf tenv e name =
  let typed = check_expression cf tenv e in
  if not (has_int_or_real_type typed) then
    Semantic_error.int_or_real_expected typed.emeta.loc name typed.emeta.type_
    |> error
  else typed
(* Typecheck [e] and accept it if its type is a scalar type or exactly [t];
   otherwise report an error that mentions [name] and the expected [t]. *)
let check_expression_of_scalar_or_type cf tenv t e name =
  let typed = check_expression cf tenv e in
  let found = typed.emeta.type_ in
  let acceptable = UnsizedType.is_scalar_type found || found = t in
  if acceptable then typed
  else
    Semantic_error.scalar_or_type_expected typed.emeta.loc name t found
    |> error
(* -- Statements ------------------------------------------------- *)
(* non returning functions *)
(* A call to a function whose name ends in "_lp" is only allowed where
   [in_lp_function cf] holds, or in the model or transformed parameters
   block. Otherwise an error is reported. *)
let verify_nrfn_target loc cf id =
  let is_lp_call = String.is_suffix id.name ~suffix:"_lp" in
  (* thunk keeps the original short-circuit: only evaluated for _lp calls *)
  let target_accessible () =
    in_lp_function cf || cf.current_block = Model || cf.current_block = TParam
  in
  if is_lp_call && not (target_accessible ()) then
    Semantic_error.target_plusequals_outside_model_or_logprob loc |> error
(* Typecheck a statement-level call of [id] on the already-typed arguments
   [es]. The result is an [NRFunApp] statement when a unique void signature
   matches. Otherwise an error is reported: [id] is a variable that does not
   share its name with a Stan math function, [id] is undeclared, the unique
   match returns a value, the match is ambiguous, or no signature fits. *)
let check_nrfn loc tenv id es =
  let fn_name = id.name in
  let found_arg_types () = List.map ~f:type_of_expr_typed es in
  match Env.find tenv fn_name with
  | [] ->
      Semantic_error.nonreturning_fn_expected_undeclaredident_found loc fn_name
        (Env.nearest_ident tenv fn_name)
      |> error
  | {kind= `Variable _; _} :: _
  (* variables can shadow stanlib functions, so we have to check this *)
    when not (Stan_math_signatures.is_stan_math_function_name fn_name) ->
      Semantic_error.nonreturning_fn_expected_nonfn_found loc fn_name |> error
  | _ (* a function *) -> (
      let resolution =
        SignatureMismatch.matching_function tenv fn_name (get_arg_types es)
      in
      match resolution with
      | UniqueMatch (ReturnType _, _, _) ->
          Semantic_error.nonreturning_fn_expected_returning_found loc fn_name
          |> error
      | UniqueMatch (Void, fnk, promotions) ->
          let fn_kind = fnk (Fun_kind.suffix_from_name fn_name) in
          let promoted_args = Promotion.promote_list es promotions in
          mk_typed_statement
            ~stmt:(NRFunApp (fn_kind, id, promoted_args))
            ~return_type:Incomplete ~loc
      | AmbiguousMatch sigs ->
          Semantic_error.ambiguous_function_promotion loc fn_name
            (Some (found_arg_types ()))
            sigs
          |> error
      | SignatureErrors (l, b) ->
          Semantic_error.illtyped_fn_app loc fn_name (l, b) (found_arg_types ())
          |> error)
(* Typecheck a non-returning function call statement. Arguments are
   typechecked first. Then the identifier and the _lp-suffix restriction are
   verified, and finally the call is resolved against the signatures of [id]. *)
let check_nr_fn_app loc cf tenv id es =
  let typed_args = es |> List.map ~f:(check_expression cf tenv) in
  verify_identifier id;
  verify_nrfn_target loc cf id;
  typed_args |> check_nrfn loc tenv id
(* target plus-equals / jacobian plus-equals *)
(* The right-hand side of target += / jacobian += must not be a function. *)
let verify_target_pe_expr_type loc e =
  let rhs_type = e.emeta.type_ in
  if UnsizedType.is_fun_type rhs_type then
    Semantic_error.int_or_real_container_expected loc rhs_type |> error
(* target += is only allowed where [in_lp_function cf] holds or in the model
   block. *)
let verify_target_pe_usage loc cf =
  let allowed = in_lp_function cf || cf.current_block = Model in
  if not allowed then
    Semantic_error.target_plusequals_outside_model_or_logprob loc |> error
(* Typecheck a [target += e] statement. The expression is checked before the
   placement and the expression type are validated. *)
let check_target_pe loc cf tenv e =
  let increment = check_expression cf tenv e in
  verify_target_pe_usage loc cf;
  verify_target_pe_expr_type loc increment;
  mk_typed_statement ~loc ~return_type:Incomplete ~stmt:(TargetPE increment)
(* jacobian += is only allowed where [in_jacobian_function cf] holds or in the
   transformed parameters block. *)
let verify_jacobian_pe_usage loc cf =
  let allowed = in_jacobian_function cf || cf.current_block = TParam in
  if not allowed then Semantic_error.jacobian_plusequals_not_allowed loc |> error
(* Typecheck a [jacobian += e] statement. The expression is checked before
   the placement and the expression type are validated (using the same type
   rule as target +=). *)
let check_jacobian_pe loc cf tenv e =
  let increment = check_expression cf tenv e in
  verify_jacobian_pe_usage loc cf;
  verify_target_pe_expr_type loc increment;
  mk_typed_statement ~loc ~return_type:Incomplete ~stmt:(JacobianPE increment)
(* assignments *)
(* Report an error when assigning to an identifier flagged as read-only. *)
let verify_assignment_read_only loc is_readonly id =
  match is_readonly with
  | false -> ()
  | true -> Semantic_error.cannot_assign_to_read_only loc id.name |> error
(* Variables from previous blocks are read-only.
   In particular, data and parameters can never be assigned to.
*)
let verify_assignment_global loc cf block is_global id =