Fix some polymorphic comparison bugs

This commit is contained in:
Tom Jack 2020-06-30 14:31:04 -05:00
parent e9db0afffa
commit ead7832c50
9 changed files with 30 additions and 13 deletions

View File

@ -3,7 +3,7 @@ module I = Ast_imperative
module O = Ast_sugar module O = Ast_sugar
open Trace open Trace
let compare_var : O.expression_variable -> O.expression_variable -> int = fun (a:O.expression_variable) (b:O.expression_variable) -> Var.compare a.wrap_content b.wrap_content let compare_var = Location.compare_content ~compare:Var.compare
let rec add_to_end (expression: O.expression) to_add = let rec add_to_end (expression: O.expression) to_add =
match expression.expression_content with match expression.expression_content with
@ -25,10 +25,10 @@ let repair_mutable_variable_in_matching (match_body : O.expression) (element_nam
ok (true,(name::decl_var, free_var),O.e_let_in let_binder false false rhs let_result) ok (true,(name::decl_var, free_var),O.e_let_in let_binder false false rhs let_result)
| E_let_in {let_binder;mut=true; rhs;let_result} -> | E_let_in {let_binder;mut=true; rhs;let_result} ->
let (name,_) = let_binder in let (name,_) = let_binder in
if List.mem name decl_var then if List.mem ~compare:compare_var name decl_var then
ok (true,(decl_var, free_var), O.e_let_in let_binder false false rhs let_result) ok (true,(decl_var, free_var), O.e_let_in let_binder false false rhs let_result)
else( else(
let free_var = if (List.mem name free_var) then free_var else name::free_var in let free_var = if (List.mem ~compare:compare_var name free_var) then free_var else name::free_var in
let expr = O.e_let_in (env,None) false false (O.e_update (O.e_variable env) [O.Access_record (Var.to_name name.wrap_content)] (O.e_variable name)) let_result in let expr = O.e_let_in (env,None) false false (O.e_update (O.e_variable env) [O.Access_record (Var.to_name name.wrap_content)] (O.e_variable name)) let_result in
ok (true,(decl_var, free_var), O.e_let_in let_binder false false rhs expr) ok (true,(decl_var, free_var), O.e_let_in let_binder false false rhs expr)
) )
@ -65,10 +65,13 @@ and repair_mutable_variable_in_loops (for_body : O.expression) (element_names :
ok (true,(name::decl_var, free_var),ass_exp) ok (true,(name::decl_var, free_var),ass_exp)
| E_let_in {let_binder;mut=true; rhs;let_result} -> | E_let_in {let_binder;mut=true; rhs;let_result} ->
let (name,_) = let_binder in let (name,_) = let_binder in
if List.mem name decl_var then if List.mem ~compare:compare_var name decl_var then
ok (true,(decl_var, free_var), O.e_let_in let_binder false false rhs let_result) ok (true,(decl_var, free_var), O.e_let_in let_binder false false rhs let_result)
else( else(
let free_var = if (List.mem name free_var) then free_var else name::free_var in let free_var =
if (List.mem ~compare:compare_var name free_var)
then free_var
else name::free_var in
let expr = O.e_let_in (env,None) false false ( let expr = O.e_let_in (env,None) false false (
O.e_update (O.e_variable env) [O.Access_tuple Z.zero; O.Access_record (Var.to_name name.wrap_content)] (O.e_variable name) O.e_update (O.e_variable env) [O.Access_tuple Z.zero; O.Access_record (Var.to_name name.wrap_content)] (O.e_variable name)
) )

View File

@ -66,7 +66,9 @@ let rec decompile_expression : O.expression -> (I.expression, desugaring_error)
let%bind fun_type = decompile_type_expression fun_type in let%bind fun_type = decompile_type_expression fun_type in
let%bind lambda = decompile_lambda lambda in let%bind lambda = decompile_lambda lambda in
return @@ I.E_recursive {fun_name;fun_type;lambda} return @@ I.E_recursive {fun_name;fun_type;lambda}
| O.E_let_in {let_binder;inline=false;rhs=expr1;let_result=expr2} when let_binder = (Location.wrap @@ Var.of_name "_", Some (O.t_unit ())) -> | O.E_let_in {let_binder = (var, ty);inline=false;rhs=expr1;let_result=expr2}
when Var.equal var.wrap_content (Var.of_name "_")
&& Pervasives.(=) ty (Some (O.t_unit ())) ->
let%bind expr1 = decompile_expression expr1 in let%bind expr1 = decompile_expression expr1 in
let%bind expr2 = decompile_expression expr2 in let%bind expr2 = decompile_expression expr2 in
return @@ I.E_sequence {expr1;expr2} return @@ I.E_sequence {expr1;expr2}

View File

@ -743,7 +743,7 @@ and type_lambda e {
match result.content with match result.content with
| I.E_let_in li -> ( | I.E_let_in li -> (
match li.rhs.content with match li.rhs.content with
| I.E_variable name when name = (binder) -> ( | I.E_variable name when Location.equal_content ~equal:Var.equal name binder -> (
match snd li.let_binder with match snd li.let_binder with
| Some ty -> ok ty | Some ty -> ok ty
| None -> default_action li.rhs () | None -> default_action li.rhs ()

View File

@ -2,6 +2,8 @@ open Errors
open Ast_typed open Ast_typed
open Trace open Trace
let var_equal = Location.equal_content ~equal:Var.equal
let rec check_recursive_call : expression_variable -> bool -> expression -> (unit, self_ast_typed_error) result = fun n final_path e -> let rec check_recursive_call : expression_variable -> bool -> expression -> (unit, self_ast_typed_error) result = fun n final_path e ->
match e.expression_content with match e.expression_content with
| E_literal _ -> ok () | E_literal _ -> ok ()
@ -10,7 +12,7 @@ let rec check_recursive_call : expression_variable -> bool -> expression -> (uni
ok () ok ()
| E_variable v -> ( | E_variable v -> (
let%bind _ = Assert.assert_true (recursive_call_is_only_allowed_as_the_last_operation n e.location) let%bind _ = Assert.assert_true (recursive_call_is_only_allowed_as_the_last_operation n e.location)
(final_path || n <> v) in (final_path || not (var_equal n v)) in
ok () ok ()
) )
| E_application {lamb;args} -> | E_application {lamb;args} ->

View File

@ -128,7 +128,7 @@ let rec subst_expression : body:expression -> x:var_name -> expr:expression -> e
let return_id = body in let return_id = body in
match body.content with match body.content with
| E_variable x' -> | E_variable x' ->
if x' = x if Location.equal_content ~equal:Var.equal x' x
then expr then expr
else return_id else return_id
| E_closure { binder; body } -> ( | E_closure { binder; body } -> (

View File

@ -3,7 +3,8 @@ open Types
module Free_variables = struct module Free_variables = struct
type bindings = expression_variable list type bindings = expression_variable list
let mem : expression_variable -> bindings -> bool = List.mem let var_compare = Location.compare_content ~compare:Var.compare
let mem : expression_variable -> bindings -> bool = List.mem ~compare:var_compare
let singleton : expression_variable -> bindings = fun s -> [ s ] let singleton : expression_variable -> bindings = fun s -> [ s ]
let union : bindings -> bindings -> bindings = (@) let union : bindings -> bindings -> bindings = (@)
let unions : bindings list -> bindings = List.concat let unions : bindings list -> bindings = List.concat

View File

@ -6,7 +6,8 @@ open Misc
module Captured_variables = struct module Captured_variables = struct
type bindings = expression_variable list type bindings = expression_variable list
let mem : expression_variable -> bindings -> bool = List.mem let var_compare = Location.compare_content ~compare:Var.compare
let mem : expression_variable -> bindings -> bool = List.mem ~compare:var_compare
let singleton : expression_variable -> bindings = fun s -> [ s ] let singleton : expression_variable -> bindings = fun s -> [ s ]
let union : bindings -> bindings -> bindings = (@) let union : bindings -> bindings -> bindings = (@)
let unions : bindings list -> bindings = List.concat let unions : bindings list -> bindings = List.concat

View File

@ -4,11 +4,13 @@ open Combinators
module Free_variables = struct module Free_variables = struct
type bindings = expression_variable list type bindings = expression_variable list
let mem : expression_variable -> bindings -> bool = List.mem let var_equal = Location.equal_content ~equal:Var.equal
let var_compare = Location.compare_content ~compare:Var.compare
let mem : expression_variable -> bindings -> bool = List.mem ~compare:var_compare
let singleton : expression_variable -> bindings = fun s -> [ s ] let singleton : expression_variable -> bindings = fun s -> [ s ]
let mem_count : expression_variable -> bindings -> int = let mem_count : expression_variable -> bindings -> int =
fun x fvs -> fun x fvs ->
List.length (List.filter (fun (a:expression_variable) -> Var.equal x.wrap_content a.wrap_content) fvs) List.length (List.filter (var_equal x) fvs)
let union : bindings -> bindings -> bindings = (@) let union : bindings -> bindings -> bindings = (@)
let unions : bindings list -> bindings = List.concat let unions : bindings list -> bindings = List.concat
let empty : bindings = [] let empty : bindings = []

View File

@ -46,6 +46,12 @@ let compare_wrap ~compare:compare_content { wrap_content = wca ; location = la }
| 0 -> compare la lb | 0 -> compare la lb
| c -> c | c -> c
let compare_content ~compare:compare_content wa wb =
compare_content wa.wrap_content wb.wrap_content
let equal_content ~equal:equal_content wa wb =
equal_content wa.wrap_content wb.wrap_content
let wrap ?(loc = generated) wrap_content = { wrap_content ; location = loc } let wrap ?(loc = generated) wrap_content = { wrap_content ; location = loc }
let get_location x = x.location let get_location x = x.location
let unwrap { wrap_content ; _ } = wrap_content let unwrap { wrap_content ; _ } = wrap_content