add set fold

This commit is contained in:
galfour 2019-09-24 00:26:39 +02:00
parent c4752c5935
commit 9c3c40c9ef
4 changed files with 29 additions and 1 deletions

View File

@ -407,6 +407,7 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re
| ("LIST_MAP" , lst) -> map lst | ("LIST_MAP" , lst) -> map lst
| ("MAP_MAP" , lst) -> map lst | ("MAP_MAP" , lst) -> map lst
| ("LIST_FOLD" , lst) -> fold lst | ("LIST_FOLD" , lst) -> fold lst
| ("SET_FOLD" , lst) -> fold lst
| ("MAP_FOLD" , lst) -> fold lst | ("MAP_FOLD" , lst) -> fold lst
| _ -> ( | _ -> (
let%bind lst' = bind_map_list (transpile_annotated_expression) lst in let%bind lst' = bind_map_list (transpile_annotated_expression) lst in

View File

@ -81,6 +81,7 @@ module Simplify = struct
("set_add" , "SET_ADD") ; ("set_add" , "SET_ADD") ;
("set_remove" , "SET_REMOVE") ; ("set_remove" , "SET_REMOVE") ;
("set_iter" , "SET_ITER") ; ("set_iter" , "SET_ITER") ;
("set_fold" , "SET_FOLD") ;
("list_iter" , "LIST_ITER") ; ("list_iter" , "LIST_ITER") ;
("list_fold" , "LIST_FOLD") ; ("list_fold" , "LIST_FOLD") ;
("list_map" , "LIST_MAP") ; ("list_map" , "LIST_MAP") ;
@ -148,6 +149,7 @@ module Simplify = struct
("Set.empty" , "SET_EMPTY") ; ("Set.empty" , "SET_EMPTY") ;
("Set.add" , "SET_ADD") ; ("Set.add" , "SET_ADD") ;
("Set.remove" , "SET_REMOVE") ; ("Set.remove" , "SET_REMOVE") ;
("Set.fold" , "SET_FOLD") ;
("Map.find_opt" , "MAP_FIND_OPT") ; ("Map.find_opt" , "MAP_FIND_OPT") ;
("Map.find" , "MAP_FIND") ; ("Map.find" , "MAP_FIND") ;
@ -156,7 +158,7 @@ module Simplify = struct
("Map.remove" , "MAP_REMOVE") ; ("Map.remove" , "MAP_REMOVE") ;
("Map.iter" , "MAP_ITER") ; ("Map.iter" , "MAP_ITER") ;
("Map.map" , "MAP_MAP") ; ("Map.map" , "MAP_MAP") ;
("Map.fold" , "LIST_FOLD") ; ("Map.fold" , "MAP_FOLD") ;
("String.length", "SIZE") ; ("String.length", "SIZE") ;
("String.size", "SIZE") ; ("String.size", "SIZE") ;
@ -496,6 +498,20 @@ module Typer = struct
let%bind () = assert_eq_1 ~msg:"res init" res init in let%bind () = assert_eq_1 ~msg:"res init" res init in
ok res ok res
let set_fold = typer_3 "SET_FOLD" @@ fun lst init body ->
let%bind (arg , res) = get_t_function body in
let%bind (prec , cur) = get_t_pair arg in
let%bind key = get_t_set lst in
let msg = Format.asprintf "%a vs %a"
Ast_typed.PP.type_value key
Ast_typed.PP.type_value arg
in
trace (simple_error ("bad set fold:" ^ msg)) @@
let%bind () = assert_eq_1 ~msg:"key cur" key cur in
let%bind () = assert_eq_1 ~msg:"prec res" prec res in
let%bind () = assert_eq_1 ~msg:"res init" res init in
ok res
let map_fold = typer_3 "MAP_FOLD" @@ fun map init body -> let map_fold = typer_3 "MAP_FOLD" @@ fun map init body ->
let%bind (arg , res) = get_t_function body in let%bind (arg , res) = get_t_function body in
let%bind (prec , cur) = get_t_pair arg in let%bind (prec , cur) = get_t_pair arg in
@ -593,6 +609,7 @@ module Typer = struct
set_add ; set_add ;
set_remove ; set_remove ;
set_iter ; set_iter ;
set_fold ;
list_iter ; list_iter ;
list_map ; list_map ;
list_fold ; list_fold ;

View File

@ -9,3 +9,8 @@ function iter_op (const s : set(int)) : int is
begin begin
set_iter(s , aggregate) ; set_iter(s , aggregate) ;
end with r end with r
function fold_op (const s : set(int)) : int is
function aggregate (const i : int ; const j : int) : int is
block { skip } with i + j
block { skip } with set_fold(s , 15 , aggregate)

View File

@ -224,6 +224,11 @@ let set_arithmetic () : unit result =
expect_eq program "mem_op" expect_eq program "mem_op"
(e_set [e_string "foo" ; e_string "bar"]) (e_set [e_string "foo" ; e_string "bar"])
(e_bool false) in (e_bool false) in
let%bind () =
expect_eq program_1 "fold_op"
(e_set [ e_int 4 ; e_int 10 ])
(e_int 29)
in
ok () ok ()
let unit_expression () : unit result = let unit_expression () : unit result =