diff --git a/src/stages/adt_generator/README b/src/stages/adt_generator/README new file mode 100644 index 000000000..20ecdfd43 --- /dev/null +++ b/src/stages/adt_generator/README @@ -0,0 +1,7 @@ +Build with: + + dune build adt_generator.a + +Run with + + python ./generator.py diff --git a/src/stages/adt_generator/a.ml b/src/stages/adt_generator/a.ml new file mode 100644 index 000000000..f1d8b2fb1 --- /dev/null +++ b/src/stages/adt_generator/a.ml @@ -0,0 +1,17 @@ +type root = +| A of a +| B of int +| C of string + +and a = { + a1 : ta1 ; + a2 : ta2 ; +} + +and ta1 = +| X of root +| Y of ta2 + +and ta2 = +| Z of ta2 +| W of unit diff --git a/src/stages/adt_generator/adt_generator.ml b/src/stages/adt_generator/adt_generator.ml new file mode 100644 index 000000000..9c1ff4b88 --- /dev/null +++ b/src/stages/adt_generator/adt_generator.ml @@ -0,0 +1,2 @@ +module A = A +module Use_a_fold = Use_a_fold diff --git a/src/stages/adt_generator/dune b/src/stages/adt_generator/dune new file mode 100644 index 000000000..4a52c6088 --- /dev/null +++ b/src/stages/adt_generator/dune @@ -0,0 +1,18 @@ +(rule + (target fold.ml) + (deps generator.py) + (action (with-stdout-to fold.ml (run python3 ./generator.py))) + (mode (promote (until-clean)))) +; (library +; (name adt_generator) +; (public_name ligo.adt_generator) +; (libraries +; ) +; ) + +(executable + (name adt_generator) + (public_name ligo.adt_generator) + (libraries + ) +) diff --git a/src/stages/adt_generator/fold.ml b/src/stages/adt_generator/fold.ml new file mode 100644 index 000000000..4e4c41357 --- /dev/null +++ b/src/stages/adt_generator/fold.ml @@ -0,0 +1,184 @@ +open A + +type root' = + | A' of a' + | B' of int + | C' of string +and a' = + { + a1' : ta1' ; + a2' : ta2' ; + } +and ta1' = + | X' of root' + | Y' of ta2' +and ta2' = + | Z' of ta2' + | W' of unit + +type 'state continue_fold = + { + root : root -> 'state -> (root' * 'state) ; + root_A : a -> 'state -> (a' * 'state) ; + root_B : int -> 'state -> (int * 'state) ; + root_C : string -> 'state -> (string * 'state) ; + a : a -> 'state -> (a' * 'state) ; + a_a1 : ta1 -> 'state -> (ta1' * 'state) ; + a_a2 : ta2 -> 'state -> (ta2' * 'state) ; + ta1 : ta1 -> 'state -> (ta1' * 'state) ; + ta1_X : root -> 'state -> (root' * 'state) ; + ta1_Y : ta2 -> 'state -> (ta2' * 'state) ; + ta2 : ta2 -> 'state -> (ta2' * 'state) ; + ta2_Z : ta2 -> 'state -> (ta2' * 'state) ; + ta2_W : unit -> 'state -> (unit * 'state) ; + } + +type 'state fold_config = + { + root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + root_pre_state : root -> 'state -> 'state ; + root_post_state : root -> root' -> 'state -> 'state ; + root_A : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + root_B : int -> 'state -> ('state continue_fold) -> (int * 'state) ; + root_C : string -> 'state -> ('state continue_fold) -> (string * 'state) ; + a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + a_pre_state : a -> 'state -> 'state ; + a_post_state : a -> a' -> 'state -> 'state ; + a_a1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + a_a2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + ta1_pre_state : ta1 -> 'state -> 'state ; + ta1_post_state : ta1 -> ta1' -> 'state -> 'state ; + ta1_X : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + ta1_Y : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_pre_state : ta2 -> 'state -> 'state ; + ta2_post_state : ta2 -> ta2' -> 'state -> 'state ; + ta2_Z : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_W : unit -> 'state -> ('state continue_fold) -> (unit * 'state) ; + } + +(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *) +let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor -> + { + root = fold_root visitor ; + root_A = fold_root_A visitor ; + root_B = fold_root_B visitor ; + root_C = fold_root_C visitor ; + a = fold_a visitor ; + a_a1 = fold_a_a1 visitor ; + a_a2 = fold_a_a2 visitor ; + ta1 = fold_ta1 visitor ; + ta1_X = fold_ta1_X visitor ; + ta1_Y = fold_ta1_Y visitor ; + ta2 = fold_ta2 visitor ; + ta2_Z = fold_ta2_Z visitor ; + ta2_W = fold_ta2_W visitor ; +} + +and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.root_pre_state x state in + let (new_x, state) = visitor.root x state continue_fold in + let state = visitor.root_post_state x new_x state in + (new_x, state) + +and fold_root_A : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_A x state continue_fold + +and fold_root_B : type state . state fold_config -> int -> state -> (int * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_B x state continue_fold + +and fold_root_C : type state . state fold_config -> string -> state -> (string * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_C x state continue_fold + +and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.a_pre_state x state in + let (new_x, state) = visitor.a x state continue_fold in + let state = visitor.a_post_state x new_x state in + (new_x, state) + +and fold_a_a1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a_a1 x state continue_fold + +and fold_a_a2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a_a2 x state continue_fold + +and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.ta1_pre_state x state in + let (new_x, state) = visitor.ta1 x state continue_fold in + let state = visitor.ta1_post_state x new_x state in + (new_x, state) + +and fold_ta1_X : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1_X x state continue_fold + +and fold_ta1_Y : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1_Y x state continue_fold + +and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.ta2_pre_state x state in + let (new_x, state) = visitor.ta2 x state continue_fold in + let state = visitor.ta2_post_state x new_x state in + (new_x, state) + +and fold_ta2_Z : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2_Z x state continue_fold + +and fold_ta2_W : type state . state fold_config -> unit -> state -> (unit * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2_W x state continue_fold + +let no_op : 'a fold_config = { + root = (fun v state continue -> + match v with + | A v -> let (v, state) = continue.root_A v state in (A' v, state) + | B v -> let (v, state) = continue.root_B v state in (B' v, state) + | C v -> let (v, state) = continue.root_C v state in (C' v, state) + ); + root_pre_state = (fun v state -> ignore v; state) ; + root_post_state = (fun v new_v state -> ignore (v, new_v); state) ; + root_A = (fun v state continue -> continue.a v state ) ; + root_B = (fun v state continue -> ignore continue; (v, state) ) ; + root_C = (fun v state continue -> ignore continue; (v, state) ) ; + a = (fun v state continue -> + match v with + { a1; a2; } -> + let (a1', state) = continue.a_a1 a1 state in + let (a2', state) = continue.a_a2 a2 state in + ({ a1'; a2'; }, state) + ); + a_pre_state = (fun v state -> ignore v; state) ; + a_post_state = (fun v new_v state -> ignore (v, new_v); state) ; + a_a1 = (fun v state continue -> continue.ta1 v state ) ; + a_a2 = (fun v state continue -> continue.ta2 v state ) ; + ta1 = (fun v state continue -> + match v with + | X v -> let (v, state) = continue.ta1_X v state in (X' v, state) + | Y v -> let (v, state) = continue.ta1_Y v state in (Y' v, state) + ); + ta1_pre_state = (fun v state -> ignore v; state) ; + ta1_post_state = (fun v new_v state -> ignore (v, new_v); state) ; + ta1_X = (fun v state continue -> continue.root v state ) ; + ta1_Y = (fun v state continue -> continue.ta2 v state ) ; + ta2 = (fun v state continue -> + match v with + | Z v -> let (v, state) = continue.ta2_Z v state in (Z' v, state) + | W v -> let (v, state) = continue.ta2_W v state in (W' v, state) + ); + ta2_pre_state = (fun v state -> ignore v; state) ; + ta2_post_state = (fun v new_v state -> ignore (v, new_v); state) ; + ta2_Z = (fun v state continue -> continue.ta2 v state ) ; + ta2_W = (fun v state continue -> ignore continue; (v, state) ) ; +} diff --git a/src/stages/adt_generator/generator.py b/src/stages/adt_generator/generator.py new file mode 100644 index 000000000..65fe21878 --- /dev/null +++ b/src/stages/adt_generator/generator.py @@ -0,0 +1,134 @@ +moduleName = "A" +adts = [ + # typename, variant?, fields_or_ctors + ("root", True, [ + # ctor, builtin, type + ("A", False, "a"), + ("B", True, "int"), + ("C", True, "string"), + ]), + ("a", False, [ + ("a1", False, "ta1"), + ("a2", False, "ta2"), + ]), + ("ta1", True, [ + ("X", False, "root"), + ("Y", False, "ta2"), + ]), + ("ta2", True, [ + ("Z", False, "ta2"), + ("W", True, "unit"), + ]), +] + +from collections import namedtuple +adt = namedtuple('adt', ['name', 'newName', 'isVariant', 'ctorsOrFields']) +ctorOrField = namedtuple('ctorOrField', ['name', 'newName', 'isBuiltin', 'type_', 'newType']) +adts = [ + adt( + name = name, + newName = f"{name}'", + isVariant = isVariant, + ctorsOrFields = [ + ctorOrField( + name = cf, + newName = f"{cf}'", + isBuiltin = isBuiltin, + type_ = type_, + newType = type_ if isBuiltin else f"{type_}'", + ) + for (cf, isBuiltin, type_) in ctors + ], + ) + for (name, isVariant, ctors) in adts +] + +print("open %s" % moduleName) + +print("") +for (index, t) in enumerate(adts): + typeOrAnd = "type" if index == 0 else "and" + print(f"{typeOrAnd} {t.newName} =") + if t.isVariant: + for c in t.ctorsOrFields: + print(f" | {c.newName} of {c.newType}") + else: + print(" {") + for f in t.ctorsOrFields: + print(f" {f.newName} : {f.newType} ;") + print(" }") + +print("") +print(f"type 'state continue_fold =") +print(" {") +for t in adts: + print(f" {t.name} : {t.name} -> 'state -> ({t.newName} * 'state) ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} : {c.type_} -> 'state -> ({c.newType} * 'state) ;") +print(" }") + +print("") +print(f"type 'state fold_config =") +print(" {") +for t in adts: + print(f" {t.name} : {t.name} -> 'state -> ('state continue_fold) -> ({t.newName} * 'state) ;") + print(f" {t.name}_pre_state : {t.name} -> 'state -> 'state ;") + print(f" {t.name}_post_state : {t.name} -> {t.newName} -> 'state -> 'state ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} : {c.type_} -> 'state -> ('state continue_fold) -> ({c.newType} * 'state) ;") +print(" }") + +print("") +print('(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *)') +print("let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->") +print(" {") +for t in adts: + print(f" {t.name} = fold_{t.name} visitor ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} = fold_{t.name}_{c.name} visitor ;") +print("}") +print("") + +for t in adts: + print(f"and fold_{t.name} : type state . state fold_config -> {t.name} -> state -> ({t.newName} * state) = fun visitor x state ->") + print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + print(f" let state = visitor.{t.name}_pre_state x state in") + print(f" let (new_x, state) = visitor.{t.name} x state continue_fold in") + print(f" let state = visitor.{t.name}_post_state x new_x state in") + print(" (new_x, state)") + print("") + for c in t.ctorsOrFields: + print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->") + print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + print(f" visitor.{t.name}_{c.name} x state continue_fold") + print("") + +print("let no_op : 'a fold_config = {") +for t in adts: + print(f" {t.name} = (fun v state continue ->") + print(" match v with") + if t.isVariant: + for c in t.ctorsOrFields: + print(f" | {c.name} v -> let (v, state) = continue.{t.name}_{c.name} v state in ({c.newName} v, state)") + else: + print(" {", end=' ') + for f in t.ctorsOrFields: + print(f"{f.name};", end=' ') + print("} ->") + for f in t.ctorsOrFields: + print(f" let ({f.newName}, state) = continue.{t.name}_{f.name} {f.name} state in") + print(" ({", end=' ') + for f in t.ctorsOrFields: + print(f"{f.newName};", end=' ') + print("}, state)") + print(" );") + print(f" {t.name}_pre_state = (fun v state -> ignore v; state) ;") + print(f" {t.name}_post_state = (fun v new_v state -> ignore (v, new_v); state) ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} = (fun v state continue ->", end=' ') + if c.isBuiltin: + print("ignore continue; (v, state)", end=' ') + else: + print(f"continue.{c.type_} v state", end=' ') + print(") ;") +print("}") diff --git a/src/stages/adt_generator/use_a_fold.ml b/src/stages/adt_generator/use_a_fold.ml new file mode 100644 index 000000000..6a73f4782 --- /dev/null +++ b/src/stages/adt_generator/use_a_fold.ml @@ -0,0 +1,48 @@ +open A +open Fold + +(* TODO: how should we plug these into our test framework? *) + +let () = + let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in + let op = { + no_op with + a = fun the_a state continue_fold -> + let (a1' , state') = continue_fold.ta1 the_a.a1 state in + let (a2' , state'') = continue_fold.ta2 the_a.a2 state' in + ({ + a1' = a1' ; + a2' = a2' ; + }, state'' + 1) + } in + let state = 0 in + let (_, state) = fold_root op some_root state in + if state != 2 then + failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state) + else + () + +let () = + let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in + let op = { no_op with a_pre_state = fun _the_a state -> state + 1 } in + let state = 0 in + let (_, state) = fold_root op some_root state in + if state != 2 then + failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state) + else + () + +let () = + let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in + let op = { no_op with a_post_state = fun _the_a _new_a state -> state + 1 } in + let state = 0 in + let (_, state) = fold_root op some_root state in + if state != 2 then + failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state) + else + () + + +(* Test that the same fold_config can be ascibed with different 'a type arguments *) +let _noi : int fold_config = no_op (* (fun _ -> ()) *) +let _nob : bool fold_config = no_op (* (fun _ -> ()) *)