add set fold
This commit is contained in:
parent
c4752c5935
commit
9c3c40c9ef
@ -407,6 +407,7 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re
|
||||
| ("LIST_MAP" , lst) -> map lst
|
||||
| ("MAP_MAP" , lst) -> map lst
|
||||
| ("LIST_FOLD" , lst) -> fold lst
|
||||
| ("SET_FOLD" , lst) -> fold lst
|
||||
| ("MAP_FOLD" , lst) -> fold lst
|
||||
| _ -> (
|
||||
let%bind lst' = bind_map_list (transpile_annotated_expression) lst in
|
||||
|
@ -81,6 +81,7 @@ module Simplify = struct
|
||||
("set_add" , "SET_ADD") ;
|
||||
("set_remove" , "SET_REMOVE") ;
|
||||
("set_iter" , "SET_ITER") ;
|
||||
("set_fold" , "SET_FOLD") ;
|
||||
("list_iter" , "LIST_ITER") ;
|
||||
("list_fold" , "LIST_FOLD") ;
|
||||
("list_map" , "LIST_MAP") ;
|
||||
@ -148,6 +149,7 @@ module Simplify = struct
|
||||
("Set.empty" , "SET_EMPTY") ;
|
||||
("Set.add" , "SET_ADD") ;
|
||||
("Set.remove" , "SET_REMOVE") ;
|
||||
("Set.fold" , "SET_FOLD") ;
|
||||
|
||||
("Map.find_opt" , "MAP_FIND_OPT") ;
|
||||
("Map.find" , "MAP_FIND") ;
|
||||
@ -156,7 +158,7 @@ module Simplify = struct
|
||||
("Map.remove" , "MAP_REMOVE") ;
|
||||
("Map.iter" , "MAP_ITER") ;
|
||||
("Map.map" , "MAP_MAP") ;
|
||||
("Map.fold" , "LIST_FOLD") ;
|
||||
("Map.fold" , "MAP_FOLD") ;
|
||||
|
||||
("String.length", "SIZE") ;
|
||||
("String.size", "SIZE") ;
|
||||
@ -496,6 +498,20 @@ module Typer = struct
|
||||
let%bind () = assert_eq_1 ~msg:"res init" res init in
|
||||
ok res
|
||||
|
||||
let set_fold = typer_3 "SET_FOLD" @@ fun lst init body ->
|
||||
let%bind (arg , res) = get_t_function body in
|
||||
let%bind (prec , cur) = get_t_pair arg in
|
||||
let%bind key = get_t_set lst in
|
||||
let msg = Format.asprintf "%a vs %a"
|
||||
Ast_typed.PP.type_value key
|
||||
Ast_typed.PP.type_value arg
|
||||
in
|
||||
trace (simple_error ("bad set fold:" ^ msg)) @@
|
||||
let%bind () = assert_eq_1 ~msg:"key cur" key cur in
|
||||
let%bind () = assert_eq_1 ~msg:"prec res" prec res in
|
||||
let%bind () = assert_eq_1 ~msg:"res init" res init in
|
||||
ok res
|
||||
|
||||
let map_fold = typer_3 "MAP_FOLD" @@ fun map init body ->
|
||||
let%bind (arg , res) = get_t_function body in
|
||||
let%bind (prec , cur) = get_t_pair arg in
|
||||
@ -593,6 +609,7 @@ module Typer = struct
|
||||
set_add ;
|
||||
set_remove ;
|
||||
set_iter ;
|
||||
set_fold ;
|
||||
list_iter ;
|
||||
list_map ;
|
||||
list_fold ;
|
||||
|
@ -9,3 +9,8 @@ function iter_op (const s : set(int)) : int is
|
||||
begin
|
||||
set_iter(s , aggregate) ;
|
||||
end with r
|
||||
|
||||
function fold_op (const s : set(int)) : int is
|
||||
function aggregate (const i : int ; const j : int) : int is
|
||||
block { skip } with i + j
|
||||
block { skip } with set_fold(s , 15 , aggregate)
|
||||
|
@ -224,6 +224,11 @@ let set_arithmetic () : unit result =
|
||||
expect_eq program "mem_op"
|
||||
(e_set [e_string "foo" ; e_string "bar"])
|
||||
(e_bool false) in
|
||||
let%bind () =
|
||||
expect_eq program_1 "fold_op"
|
||||
(e_set [ e_int 4 ; e_int 10 ])
|
||||
(e_int 29)
|
||||
in
|
||||
ok ()
|
||||
|
||||
let unit_expression () : unit result =
|
||||
|
Loading…
Reference in New Issue
Block a user