diff --git a/src/passes/8-typer-old/typer.ml b/src/passes/8-typer-old/typer.ml index 8348653cf..b94a51475 100644 --- a/src/passes/8-typer-old/typer.ml +++ b/src/passes/8-typer-old/typer.ml @@ -88,6 +88,15 @@ module Errors = struct ("environment" , fun () -> Format.asprintf "%a" Environment.PP.full_environment e) ; ] in error ~data title message () + + let michelson_or (c:I.constructor') loc () = + let title = (thunk "michelson_or types must be annotated") in + let message () = "" in + let data = [ + ("constructor" , fun () -> Format.asprintf "%a" I.PP.constructor c); + ("location" , fun () -> Format.asprintf "%a" Location.pp loc) + ] in + error ~data title message () let wrong_arity (n:string) (expected:int) (actual:int) (loc : Location.t) () = let title () = "wrong arity" in @@ -341,7 +350,10 @@ and evaluate_type (e:environment) (t:I.type_expression) : O.type_expression resu let%bind prev' = prev in let%bind v' = evaluate_type e v in let%bind () = match Environment.get_constructor k e with - | Some _ -> fail (redundant_constructor e k) + | Some _ -> + if I.CMap.mem (Constructor "M_left") m || I.CMap.mem (Constructor "M_right") m then + ok () + else fail (redundant_constructor e k) | None -> ok () in ok @@ I.CMap.add k v' prev' in @@ -477,6 +489,17 @@ and type_expression' : environment -> ?tv_opt:O.type_expression -> I.expression | None -> ok () | Some tv' -> O.assert_type_expression_eq (tv' , ae.type_expression) in ok(ae) + | E_constructor {constructor = Constructor s ; element} when String.equal s "M_left" || String.equal s "M_right" -> ( + let%bind t = trace_option (Errors.michelson_or (Constructor s) ae.location) @@ tv_opt in + let%bind expr' = type_expression' e element in + ( match t.type_content with + | T_sum c -> + let ct = I.CMap.find (I.Constructor s) c in + let%bind _assert = O.assert_type_expression_eq (expr'.type_expression, ct) in + return (E_constructor {constructor = Constructor s; element=expr'}) t + | _ -> simple_fail "ll" + ) + ) (* Sum *) | E_constructor {constructor; element} -> let%bind (c_tv, sum_tv) =