From 3605768bb04fd1d3c7473364ffbf07bb34f50046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Suzanne=20Dup=C3=A9ron?= Date: Wed, 11 Dec 2019 17:54:00 +0100 Subject: [PATCH 1/4] Started auto-generation of folds on ADTs (part of the code is generated, not all) --- src/stages/adt_generator/README | 7 +++ src/stages/adt_generator/a.ml | 16 +++++ src/stages/adt_generator/adt_generator.ml | 2 + src/stages/adt_generator/dune | 6 ++ src/stages/adt_generator/fold.ml | 70 ++++++++++++++++++++++ src/stages/adt_generator/generator.py | 71 +++++++++++++++++++++++ src/stages/adt_generator/use_a_fold.ml | 18 ++++++ 7 files changed, 190 insertions(+) create mode 100644 src/stages/adt_generator/README create mode 100644 src/stages/adt_generator/a.ml create mode 100644 src/stages/adt_generator/adt_generator.ml create mode 100644 src/stages/adt_generator/dune create mode 100644 src/stages/adt_generator/fold.ml create mode 100644 src/stages/adt_generator/generator.py create mode 100644 src/stages/adt_generator/use_a_fold.ml diff --git a/src/stages/adt_generator/README b/src/stages/adt_generator/README new file mode 100644 index 000000000..20ecdfd43 --- /dev/null +++ b/src/stages/adt_generator/README @@ -0,0 +1,7 @@ +Build with: + + dune build adt_generator.a + +Run with + + python ./generator.py diff --git a/src/stages/adt_generator/a.ml b/src/stages/adt_generator/a.ml new file mode 100644 index 000000000..c30254e99 --- /dev/null +++ b/src/stages/adt_generator/a.ml @@ -0,0 +1,16 @@ +type root = +| A of a +| B of int +| C of string + +and a = { + a1 : ta1 ; + a2 : ta2 ; +} + +and ta1 = +| X of root +| Y of ta2 + +and ta2 = +| Z of ta2 diff --git a/src/stages/adt_generator/adt_generator.ml b/src/stages/adt_generator/adt_generator.ml new file mode 100644 index 000000000..9c1ff4b88 --- /dev/null +++ b/src/stages/adt_generator/adt_generator.ml @@ -0,0 +1,2 @@ +module A = A +module Use_a_fold = Use_a_fold diff --git a/src/stages/adt_generator/dune b/src/stages/adt_generator/dune new file mode 100644 index 000000000..0e1a15f71 --- /dev/null +++ b/src/stages/adt_generator/dune @@ -0,0 +1,6 @@ +(library + (name adt_generator) + (public_name ligo.adt_generator) + (libraries + ) +) diff --git a/src/stages/adt_generator/fold.ml b/src/stages/adt_generator/fold.ml new file mode 100644 index 000000000..f46ca6b5e --- /dev/null +++ b/src/stages/adt_generator/fold.ml @@ -0,0 +1,70 @@ +open A + +type root' = +| A' of a' +| B' of int +| C' of string + +and a' = { + a1' : ta1' ; + a2' : ta2' ; +} + +and ta1' = +| X' of root' +| Y' of ta2' + +and ta2' = +| Z' of ta2' + +type 'state continue_fold = { + a : a -> 'state -> (a' * 'state) ; + ta1 : ta1 -> 'state -> (ta1' * 'state) ; + ta2 : ta2 -> 'state -> (ta2' * 'state) ; + root : root -> 'state -> (root' * 'state) ; +} + +type 'state fold_config = { + root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + root_a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + root_b : int -> 'state -> ('state continue_fold) -> (int * 'state) ; + root_c : string -> 'state -> ('state continue_fold) -> (string * 'state) ; + a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + } + +let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor -> + { + a = fold_a visitor ; + ta1 = fold_ta1 visitor ; + ta2 = fold_ta2 visitor ; + root = fold_root visitor ; + } + +and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + match x with + | { a1; a2 } -> + let (a1', state) = visitor.ta1 a1 state continue_fold in + let (a2', state) = visitor.ta2 a2 state continue_fold in + ({ a1'; a2' }, state) + +and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + match x with + | Z v -> let (v, state) = visitor.ta2 v state continue_fold in (Z' v, state) + +and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + match x with + | X v -> let (v, state) = visitor.root v state continue_fold in (X' v , state) + | Y v -> let (v, state) = visitor.ta2 v state continue_fold in (Y' v , state) + +and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + match x with + | A v -> let (v, state) = visitor.a v state continue_fold in (A' v , state) + | B v -> let (v, state) = visitor.root_b v state continue_fold in (B' v , state) + | C v -> let (v, state) = visitor.root_c v state continue_fold in (C' v , state) +let no_op = failwith "todo" diff --git a/src/stages/adt_generator/generator.py b/src/stages/adt_generator/generator.py new file mode 100644 index 000000000..ad537f4d7 --- /dev/null +++ b/src/stages/adt_generator/generator.py @@ -0,0 +1,71 @@ +adts = [ + # typename, variant?, fields_or_ctors + ("root", True, [ + # ctor, builtin, type + ("A", False, "a"), + ("B", True, "int"), + ("C", True, "string"), + ]), + ("a", False, [ + ("a1", False, "ta1"), + ("a2", False, "ta2"), + ]), + ("ta1", True, [ + ("X", False, "root"), + ("Y", False, "ta2"), + ]), + ("ta2", True, [ + ("Z", False, "ta2"), + ]), +] + +print "type 'state fold_config = {" +for (t, is_variant, ctors) in adts: + tt = ("%s'" % (t,)) # output type t' + print (" %s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, t, tt)) + for (c, builtin, ct,) in ctors: + if builtin: + ctt = ct # TODO: use a wrapper instead of a' for the intermediate steps, and target a different type a' just to change what the output type is + else: + ctt = ("%s'" % (ct,)) + print (" %s_%s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, c, ct, ctt)) +print " }" +print "" + +print "let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->" +print " {" +for (t, is_variant, ctors) in adts: + print (" %s = fold_%s visitor ;" % (t, t)) +print " }" +print "" + +for (t, is_variant, ctors) in adts: + v = t # visitor field + tt = ("%s'" % (t,)) # output type t' + print ("and fold_%s : type state . state fold_config -> %s -> state -> (%s * state) = fun visitor x state ->" % (t, t, tt,)) + print " let continue_fold : state continue_fold = mk_continue_fold visitor in" + print " match x with" + if is_variant: + for (c, builtin, ct,) in ctors: + cc = ("%s'" % (c,)) + print (" | %s v ->" % (c,)) + print (" let (v, state) = visitor.%s_%s v state continue_fold in" % (t, c,)) + if not builtin: + print (" let (v, state) = visitor.%s v state continue_fold in" % (ct,)) + print (" (%s v, state)" % (cc,)) + else: + print " | {" + for (f, builtin, ft,) in ctors: + print (" %s;" % (f,)) + print " } ->" + for (f, builtin, ft,) in ctors: + ff = ("%s'" % (f,)) + print (" let (%s, state) = visitor.%s_%s %s state continue_fold in" % (f, t, f, f,)) + if not builtin: + print (" let (%s, state) = visitor.%s %s state continue_fold in" % (ff, ft, f,)) + print " ({" + for (f, builtin, ft,) in ctors: + ff = ("%s'" % (f,)) + print (" %s;" % (ff,)) + print " }, state)" + print "" diff --git a/src/stages/adt_generator/use_a_fold.ml b/src/stages/adt_generator/use_a_fold.ml new file mode 100644 index 000000000..03f4cca3e --- /dev/null +++ b/src/stages/adt_generator/use_a_fold.ml @@ -0,0 +1,18 @@ +open A +open Fold + +let _ = + let some_root = ((failwith "assume we have some root") : root) in + let op = { + no_op with + a = fun the_a state continue_fold -> + let (a1' , state') = continue_fold.ta1 the_a.a1 state in + let (a2' , state'') = continue_fold.ta2 the_a.a2 state' in + ({ + a1' = a1' ; + a2' = a2' ; + }, state'') + } in + let state = () in + fold_root op some_root state + From 8b98898dbf34a508d2121f7972300325fe8b220d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Suzanne=20Dup=C3=A9ron?= Date: Thu, 12 Dec 2019 12:12:34 +0100 Subject: [PATCH 2/4] first executable version of the auto-generation of folds --- src/stages/adt_generator/a.ml | 1 + src/stages/adt_generator/dune | 14 +- src/stages/adt_generator/fold.ml | 196 ++++++++++++++++------- src/stages/adt_generator/generator.py | 207 +++++++++++++++++++------ src/stages/adt_generator/use_a_fold.ml | 12 +- 5 files changed, 323 insertions(+), 107 deletions(-) diff --git a/src/stages/adt_generator/a.ml b/src/stages/adt_generator/a.ml index c30254e99..f1d8b2fb1 100644 --- a/src/stages/adt_generator/a.ml +++ b/src/stages/adt_generator/a.ml @@ -14,3 +14,4 @@ and ta1 = and ta2 = | Z of ta2 +| W of unit diff --git a/src/stages/adt_generator/dune b/src/stages/adt_generator/dune index 0e1a15f71..4a52c6088 100644 --- a/src/stages/adt_generator/dune +++ b/src/stages/adt_generator/dune @@ -1,4 +1,16 @@ -(library +(rule + (target fold.ml) + (deps generator.py) + (action (with-stdout-to fold.ml (run python3 ./generator.py))) + (mode (promote (until-clean)))) +; (library +; (name adt_generator) +; (public_name ligo.adt_generator) +; (libraries +; ) +; ) + +(executable (name adt_generator) (public_name ligo.adt_generator) (libraries diff --git a/src/stages/adt_generator/fold.ml b/src/stages/adt_generator/fold.ml index f46ca6b5e..72d64e0a4 100644 --- a/src/stages/adt_generator/fold.ml +++ b/src/stages/adt_generator/fold.ml @@ -1,70 +1,156 @@ open A type root' = -| A' of a' -| B' of int -| C' of string - -and a' = { - a1' : ta1' ; - a2' : ta2' ; -} - + | A' of a' + | B' of int + | C' of string +and a' = + { + a1' : ta1' ; + a2' : ta2' ; + } and ta1' = -| X' of root' -| Y' of ta2' - + | X' of root' + | Y' of ta2' and ta2' = -| Z' of ta2' + | Z' of ta2' + | W' of unit -type 'state continue_fold = { - a : a -> 'state -> (a' * 'state) ; - ta1 : ta1 -> 'state -> (ta1' * 'state) ; - ta2 : ta2 -> 'state -> (ta2' * 'state) ; +type 'state continue_fold = + { root : root -> 'state -> (root' * 'state) ; -} - -type 'state fold_config = { - root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; - root_a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; - root_b : int -> 'state -> ('state continue_fold) -> (int * 'state) ; - root_c : string -> 'state -> ('state continue_fold) -> (string * 'state) ; - a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; - ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; - ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + root_A : a -> 'state -> (a' * 'state) ; + root_B : int -> 'state -> (int * 'state) ; + root_C : string -> 'state -> (string * 'state) ; + a : a -> 'state -> (a' * 'state) ; + a_a1 : ta1 -> 'state -> (ta1' * 'state) ; + a_a2 : ta2 -> 'state -> (ta2' * 'state) ; + ta1 : ta1 -> 'state -> (ta1' * 'state) ; + ta1_X : root -> 'state -> (root' * 'state) ; + ta1_Y : ta2 -> 'state -> (ta2' * 'state) ; + ta2 : ta2 -> 'state -> (ta2' * 'state) ; + ta2_Z : ta2 -> 'state -> (ta2' * 'state) ; + ta2_W : unit -> 'state -> (unit * 'state) ; } +type 'state fold_config = + { + root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + root_A : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + root_B : int -> 'state -> ('state continue_fold) -> (int * 'state) ; + root_C : string -> 'state -> ('state continue_fold) -> (string * 'state) ; + a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + a_a1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + a_a2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + ta1_X : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + ta1_Y : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_Z : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_W : unit -> 'state -> ('state continue_fold) -> (unit * 'state) ; + } + +(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *) let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor -> { - a = fold_a visitor ; - ta1 = fold_ta1 visitor ; - ta2 = fold_ta2 visitor ; root = fold_root visitor ; - } - -and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> - let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | { a1; a2 } -> - let (a1', state) = visitor.ta1 a1 state continue_fold in - let (a2', state) = visitor.ta2 a2 state continue_fold in - ({ a1'; a2' }, state) - -and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> - let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | Z v -> let (v, state) = visitor.ta2 v state continue_fold in (Z' v, state) - -and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> - let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | X v -> let (v, state) = visitor.root v state continue_fold in (X' v , state) - | Y v -> let (v, state) = visitor.ta2 v state continue_fold in (Y' v , state) + root_A = fold_root_A visitor ; + root_B = fold_root_B visitor ; + root_C = fold_root_C visitor ; + a = fold_a visitor ; + a_a1 = fold_a_a1 visitor ; + a_a2 = fold_a_a2 visitor ; + ta1 = fold_ta1 visitor ; + ta1_X = fold_ta1_X visitor ; + ta1_Y = fold_ta1_Y visitor ; + ta2 = fold_ta2 visitor ; + ta2_Z = fold_ta2_Z visitor ; + ta2_W = fold_ta2_W visitor ; +} and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | A v -> let (v, state) = visitor.a v state continue_fold in (A' v , state) - | B v -> let (v, state) = visitor.root_b v state continue_fold in (B' v , state) - | C v -> let (v, state) = visitor.root_c v state continue_fold in (C' v , state) -let no_op = failwith "todo" + visitor.root x state continue_fold + +and fold_root_A : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_A x state continue_fold + +and fold_root_B : type state . state fold_config -> int -> state -> (int * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_B x state continue_fold + +and fold_root_C : type state . state fold_config -> string -> state -> (string * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_C x state continue_fold + +and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a x state continue_fold + +and fold_a_a1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a_a1 x state continue_fold + +and fold_a_a2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a_a2 x state continue_fold + +and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1 x state continue_fold + +and fold_ta1_X : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1_X x state continue_fold + +and fold_ta1_Y : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1_Y x state continue_fold + +and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2 x state continue_fold + +and fold_ta2_Z : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2_Z x state continue_fold + +and fold_ta2_W : type state . state fold_config -> unit -> state -> (unit * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2_W x state continue_fold + +let no_op : 'a fold_config = { + root = (fun v state continue -> + match v with + | A v -> let (v, state) = continue.root_A v state in (A' v, state) + | B v -> let (v, state) = continue.root_B v state in (B' v, state) + | C v -> let (v, state) = continue.root_C v state in (C' v, state) + ); + root_A = (fun v state continue -> continue.a v state ) ; + root_B = (fun v state continue -> ignore continue; (v, state) ) ; + root_C = (fun v state continue -> ignore continue; (v, state) ) ; + a = (fun v state continue -> + match v with + { a1; a2; } -> + let (a1', state) = continue.a_a1 a1 state in + let (a2', state) = continue.a_a2 a2 state in + ({ a1'; a2'; }, state) + ); + a_a1 = (fun v state continue -> continue.ta1 v state ) ; + a_a2 = (fun v state continue -> continue.ta2 v state ) ; + ta1 = (fun v state continue -> + match v with + | X v -> let (v, state) = continue.ta1_X v state in (X' v, state) + | Y v -> let (v, state) = continue.ta1_Y v state in (Y' v, state) + ); + ta1_X = (fun v state continue -> continue.root v state ) ; + ta1_Y = (fun v state continue -> continue.ta2 v state ) ; + ta2 = (fun v state continue -> + match v with + | Z v -> let (v, state) = continue.ta2_Z v state in (Z' v, state) + | W v -> let (v, state) = continue.ta2_W v state in (W' v, state) + ); + ta2_Z = (fun v state continue -> continue.ta2 v state ) ; + ta2_W = (fun v state continue -> ignore continue; (v, state) ) ; +} diff --git a/src/stages/adt_generator/generator.py b/src/stages/adt_generator/generator.py index ad537f4d7..2c8ea5dbf 100644 --- a/src/stages/adt_generator/generator.py +++ b/src/stages/adt_generator/generator.py @@ -1,3 +1,4 @@ +moduleName = "A" adts = [ # typename, variant?, fields_or_ctors ("root", True, [ @@ -16,56 +17,168 @@ adts = [ ]), ("ta2", True, [ ("Z", False, "ta2"), + ("W", True, "unit"), ]), ] -print "type 'state fold_config = {" -for (t, is_variant, ctors) in adts: - tt = ("%s'" % (t,)) # output type t' - print (" %s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, t, tt)) - for (c, builtin, ct,) in ctors: - if builtin: - ctt = ct # TODO: use a wrapper instead of a' for the intermediate steps, and target a different type a' just to change what the output type is - else: - ctt = ("%s'" % (ct,)) - print (" %s_%s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, c, ct, ctt)) -print " }" -print "" +from collections import namedtuple +adt = namedtuple('adt', ['name', 'newName', 'isVariant', 'ctorsOrFields']) +ctorOrField = namedtuple('ctorOrField', ['name', 'newName', 'isBuiltin', 'type_', 'newType']) +adts = [ + adt( + name = name, + newName = f"{name}'", + isVariant = isVariant, + ctorsOrFields = [ + ctorOrField( + name = cf, + newName = f"{cf}'", + isBuiltin = isBuiltin, + type_ = type_, + newType = type_ if isBuiltin else f"{type_}'", + ) + for (cf, isBuiltin, type_) in ctors + ], + ) + for (name, isVariant, ctors) in adts +] -print "let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->" -print " {" -for (t, is_variant, ctors) in adts: - print (" %s = fold_%s visitor ;" % (t, t)) -print " }" -print "" +print("open %s" % moduleName) -for (t, is_variant, ctors) in adts: - v = t # visitor field - tt = ("%s'" % (t,)) # output type t' - print ("and fold_%s : type state . state fold_config -> %s -> state -> (%s * state) = fun visitor x state ->" % (t, t, tt,)) - print " let continue_fold : state continue_fold = mk_continue_fold visitor in" - print " match x with" - if is_variant: - for (c, builtin, ct,) in ctors: - cc = ("%s'" % (c,)) - print (" | %s v ->" % (c,)) - print (" let (v, state) = visitor.%s_%s v state continue_fold in" % (t, c,)) - if not builtin: - print (" let (v, state) = visitor.%s v state continue_fold in" % (ct,)) - print (" (%s v, state)" % (cc,)) +print("") +for (index, t) in enumerate(adts): + typeOrAnd = "type" if index == 0 else "and" + print(f"{typeOrAnd} {t.newName} =") + if t.isVariant: + for c in t.ctorsOrFields: + print(f" | {c.newName} of {c.newType}") else: - print " | {" - for (f, builtin, ft,) in ctors: - print (" %s;" % (f,)) - print " } ->" - for (f, builtin, ft,) in ctors: - ff = ("%s'" % (f,)) - print (" let (%s, state) = visitor.%s_%s %s state continue_fold in" % (f, t, f, f,)) - if not builtin: - print (" let (%s, state) = visitor.%s %s state continue_fold in" % (ff, ft, f,)) - print " ({" - for (f, builtin, ft,) in ctors: - ff = ("%s'" % (f,)) - print (" %s;" % (ff,)) - print " }, state)" - print "" + print(" {") + for f in t.ctorsOrFields: + print(f" {f.newName} : {f.newType} ;") + print(" }") + + +# print("") +# print("type 'state continue_fold =") +# print(" {") +# for t in adts: +# print(f" {t.name} : {t.name} -> 'state -> ({t.newName} * 'state) ;") +# print(" }") + +def folder(name, extraArgs): + print("") + print(f"type 'state {name} =") + print(" {") + for t in adts: + print(f" {t.name} : {t.name} -> 'state{extraArgs} -> ({t.newName} * 'state) ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} : {c.type_} -> 'state{extraArgs} -> ({c.newType} * 'state) ;") + print(" }") + +folder("continue_fold", "") +folder("fold_config", " -> ('state continue_fold)") + +print("") +print('(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *)') +print("let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->") +print(" {") +for t in adts: + print(f" {t.name} = fold_{t.name} visitor ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} = fold_{t.name}_{c.name} visitor ;") +print("}") +print("") + +for t in adts: + print(f"and fold_{t.name} : type state . state fold_config -> {t.name} -> state -> ({t.newName} * state) = fun visitor x state ->") + print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + print(f" visitor.{t.name} x state continue_fold") + print("") + for c in t.ctorsOrFields: + print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->") + print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + print(f" visitor.{t.name}_{c.name} x state continue_fold") + print("") + + # print(" match x with") + # if t.isVariant: + # for c in t.ctorsOrFields: + # print(f" | {c.name} v ->") + # print(f" let (v', state) = visitor.{t.name}_{c.name} v state continue_fold in") + # print(f" ({c.newName} v', state)") + # else: + # print(" | {", end=' ') + # for f in t.ctorsOrFields: + # print(f"{f.name};", end=' ') + # print("} ->") + # for f in t.ctorsOrFields: + # print(f" let ({f.newName}, state) = visitor.{t.name}_{f.name} {f.name} state continue_fold in") + # print(" ({", end=' ') + # for f in t.ctorsOrFields: + # print(f"{f.newName};", end=' ') + # print("}, state)") + # print("") + # for c in t.ctorsOrFields: + # print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->") + # if c.isBuiltin: + # print(" ignore visitor; (x, state)") + # else: + # print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + # print(f" visitor.{c.type_} x state continue_fold") + # print("") + +# print """let no_op : ('a -> unit) -> 'a fold_config = fun phantom -> failwith "todo" """ + +print("let no_op : 'a fold_config = {") +for t in adts: + print(f" {t.name} = (fun v state continue ->") + print(" match v with") + if t.isVariant: + for c in t.ctorsOrFields: + print(f" | {c.name} v -> let (v, state) = continue.{t.name}_{c.name} v state in ({c.newName} v, state)") + else: + print(" {", end=' ') + for f in t.ctorsOrFields: + print(f"{f.name};", end=' ') + print("} ->") + for f in t.ctorsOrFields: + print(f" let ({f.newName}, state) = continue.{t.name}_{f.name} {f.name} state in") + print(" ({", end=' ') + for f in t.ctorsOrFields: + print(f"{f.newName};", end=' ') + print("}, state)") + print(" );") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} = (fun v state continue ->", end=' ') + if c.isBuiltin: + print("ignore continue; (v, state)", end=' ') + else: + print(f"continue.{c.type_} v state", end=' ') + print(") ;") +print("}") + + + # (fun v state continue -> + # let (new_v, new_state) = match v with + # | A v -> let (v, state) = continue.a v state in (A' v, state) + # | B v -> let (v, state) = (fun x s -> (x,s)) v state in (B' v, state) + # | C v -> let (v, state) = (fun x s -> (x,s)) v state in (C' v, state) + # in + # (new_v, new_state) + # ); + + + + + + + # if not builtin: + # print (" let (v', state) = match v' with None -> visitor.%s v state continue_fold | Some v' -> (v', state) in" % (ct,)) + # else: + # print " let Some v' = v' in" + + # if not builtin: + # print (" let (%s, state) = match %s with None -> visitor.%s %s state continue_fold | Some v' -> (v', state) in" % (ff, ff, ft, f)) + # else: + # print " let Some v' = v' in" diff --git a/src/stages/adt_generator/use_a_fold.ml b/src/stages/adt_generator/use_a_fold.ml index 03f4cca3e..13f78e040 100644 --- a/src/stages/adt_generator/use_a_fold.ml +++ b/src/stages/adt_generator/use_a_fold.ml @@ -2,7 +2,7 @@ open A open Fold let _ = - let some_root = ((failwith "assume we have some root") : root) in + let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in let op = { no_op with a = fun the_a state continue_fold -> @@ -11,8 +11,12 @@ let _ = ({ a1' = a1' ; a2' = a2' ; - }, state'') + }, state'' + 1) } in - let state = () in - fold_root op some_root state + let state = 0 in + let (_, state) = fold_root op some_root state in + Printf.printf "trilili %d" state + +let _noi : int fold_config = no_op (* (fun _ -> ()) *) +let _nob : bool fold_config = no_op (* (fun _ -> ()) *) From fe5f8d9f64509674eafc2a530d839e5ec12cbb87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Suzanne=20Dup=C3=A9ron?= Date: Mon, 6 Jan 2020 12:58:43 +0100 Subject: [PATCH 3/4] fold_config hook to update the state after a node has been transformed, without transforming it. --- src/stages/adt_generator/fold.ml | 24 +++++-- src/stages/adt_generator/generator.py | 95 ++++++--------------------- 2 files changed, 41 insertions(+), 78 deletions(-) diff --git a/src/stages/adt_generator/fold.ml b/src/stages/adt_generator/fold.ml index 72d64e0a4..3d04a1e13 100644 --- a/src/stages/adt_generator/fold.ml +++ b/src/stages/adt_generator/fold.ml @@ -36,16 +36,20 @@ type 'state continue_fold = type 'state fold_config = { root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + root_post_state : root -> root' -> 'state -> 'state ; root_A : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; root_B : int -> 'state -> ('state continue_fold) -> (int * 'state) ; root_C : string -> 'state -> ('state continue_fold) -> (string * 'state) ; a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + a_post_state : a -> a' -> 'state -> 'state ; a_a1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; a_a2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + ta1_post_state : ta1 -> ta1' -> 'state -> 'state ; ta1_X : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; ta1_Y : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_post_state : ta2 -> ta2' -> 'state -> 'state ; ta2_Z : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; ta2_W : unit -> 'state -> ('state continue_fold) -> (unit * 'state) ; } @@ -70,7 +74,9 @@ let rec mk_continue_fold : type state . state fold_config -> state continue_fold and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in - visitor.root x state continue_fold + let (new_x, state) = visitor.root x state continue_fold in + let state = visitor.root_post_state x new_x state in + (new_x, state) and fold_root_A : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in @@ -86,7 +92,9 @@ and fold_root_C : type state . state fold_config -> string -> state -> (string * and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in - visitor.a x state continue_fold + let (new_x, state) = visitor.a x state continue_fold in + let state = visitor.a_post_state x new_x state in + (new_x, state) and fold_a_a1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in @@ -98,7 +106,9 @@ and fold_a_a2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in - visitor.ta1 x state continue_fold + let (new_x, state) = visitor.ta1 x state continue_fold in + let state = visitor.ta1_post_state x new_x state in + (new_x, state) and fold_ta1_X : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in @@ -110,7 +120,9 @@ and fold_ta1_Y : type state . state fold_config -> ta2 -> state -> (ta2' * state and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in - visitor.ta2 x state continue_fold + let (new_x, state) = visitor.ta2 x state continue_fold in + let state = visitor.ta2_post_state x new_x state in + (new_x, state) and fold_ta2_Z : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in @@ -127,6 +139,7 @@ let no_op : 'a fold_config = { | B v -> let (v, state) = continue.root_B v state in (B' v, state) | C v -> let (v, state) = continue.root_C v state in (C' v, state) ); + root_post_state = (fun v new_v state -> ignore (v, new_v); state) ; root_A = (fun v state continue -> continue.a v state ) ; root_B = (fun v state continue -> ignore continue; (v, state) ) ; root_C = (fun v state continue -> ignore continue; (v, state) ) ; @@ -137,6 +150,7 @@ let no_op : 'a fold_config = { let (a2', state) = continue.a_a2 a2 state in ({ a1'; a2'; }, state) ); + a_post_state = (fun v new_v state -> ignore (v, new_v); state) ; a_a1 = (fun v state continue -> continue.ta1 v state ) ; a_a2 = (fun v state continue -> continue.ta2 v state ) ; ta1 = (fun v state continue -> @@ -144,6 +158,7 @@ let no_op : 'a fold_config = { | X v -> let (v, state) = continue.ta1_X v state in (X' v, state) | Y v -> let (v, state) = continue.ta1_Y v state in (Y' v, state) ); + ta1_post_state = (fun v new_v state -> ignore (v, new_v); state) ; ta1_X = (fun v state continue -> continue.root v state ) ; ta1_Y = (fun v state continue -> continue.ta2 v state ) ; ta2 = (fun v state continue -> @@ -151,6 +166,7 @@ let no_op : 'a fold_config = { | Z v -> let (v, state) = continue.ta2_Z v state in (Z' v, state) | W v -> let (v, state) = continue.ta2_W v state in (W' v, state) ); + ta2_post_state = (fun v new_v state -> ignore (v, new_v); state) ; ta2_Z = (fun v state continue -> continue.ta2 v state ) ; ta2_W = (fun v state continue -> ignore continue; (v, state) ) ; } diff --git a/src/stages/adt_generator/generator.py b/src/stages/adt_generator/generator.py index 2c8ea5dbf..e69a1fbf0 100644 --- a/src/stages/adt_generator/generator.py +++ b/src/stages/adt_generator/generator.py @@ -58,26 +58,24 @@ for (index, t) in enumerate(adts): print(f" {f.newName} : {f.newType} ;") print(" }") +print("") +print(f"type 'state continue_fold =") +print(" {") +for t in adts: + print(f" {t.name} : {t.name} -> 'state -> ({t.newName} * 'state) ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} : {c.type_} -> 'state -> ({c.newType} * 'state) ;") +print(" }") -# print("") -# print("type 'state continue_fold =") -# print(" {") -# for t in adts: -# print(f" {t.name} : {t.name} -> 'state -> ({t.newName} * 'state) ;") -# print(" }") - -def folder(name, extraArgs): - print("") - print(f"type 'state {name} =") - print(" {") - for t in adts: - print(f" {t.name} : {t.name} -> 'state{extraArgs} -> ({t.newName} * 'state) ;") - for c in t.ctorsOrFields: - print(f" {t.name}_{c.name} : {c.type_} -> 'state{extraArgs} -> ({c.newType} * 'state) ;") - print(" }") - -folder("continue_fold", "") -folder("fold_config", " -> ('state continue_fold)") +print("") +print(f"type 'state fold_config =") +print(" {") +for t in adts: + print(f" {t.name} : {t.name} -> 'state -> ('state continue_fold) -> ({t.newName} * 'state) ;") + print(f" {t.name}_post_state : {t.name} -> {t.newName} -> 'state -> 'state ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} : {c.type_} -> 'state -> ('state continue_fold) -> ({c.newType} * 'state) ;") +print(" }") print("") print('(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *)') @@ -93,7 +91,9 @@ print("") for t in adts: print(f"and fold_{t.name} : type state . state fold_config -> {t.name} -> state -> ({t.newName} * state) = fun visitor x state ->") print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") - print(f" visitor.{t.name} x state continue_fold") + print(f" let (new_x, state) = visitor.{t.name} x state continue_fold in") + print(f" let state = visitor.{t.name}_post_state x new_x state in") + print(" (new_x, state)") print("") for c in t.ctorsOrFields: print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->") @@ -101,35 +101,6 @@ for t in adts: print(f" visitor.{t.name}_{c.name} x state continue_fold") print("") - # print(" match x with") - # if t.isVariant: - # for c in t.ctorsOrFields: - # print(f" | {c.name} v ->") - # print(f" let (v', state) = visitor.{t.name}_{c.name} v state continue_fold in") - # print(f" ({c.newName} v', state)") - # else: - # print(" | {", end=' ') - # for f in t.ctorsOrFields: - # print(f"{f.name};", end=' ') - # print("} ->") - # for f in t.ctorsOrFields: - # print(f" let ({f.newName}, state) = visitor.{t.name}_{f.name} {f.name} state continue_fold in") - # print(" ({", end=' ') - # for f in t.ctorsOrFields: - # print(f"{f.newName};", end=' ') - # print("}, state)") - # print("") - # for c in t.ctorsOrFields: - # print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->") - # if c.isBuiltin: - # print(" ignore visitor; (x, state)") - # else: - # print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") - # print(f" visitor.{c.type_} x state continue_fold") - # print("") - -# print """let no_op : ('a -> unit) -> 'a fold_config = fun phantom -> failwith "todo" """ - print("let no_op : 'a fold_config = {") for t in adts: print(f" {t.name} = (fun v state continue ->") @@ -149,6 +120,7 @@ for t in adts: print(f"{f.newName};", end=' ') print("}, state)") print(" );") + print(f" {t.name}_post_state = (fun v new_v state -> ignore (v, new_v); state) ;") for c in t.ctorsOrFields: print(f" {t.name}_{c.name} = (fun v state continue ->", end=' ') if c.isBuiltin: @@ -157,28 +129,3 @@ for t in adts: print(f"continue.{c.type_} v state", end=' ') print(") ;") print("}") - - - # (fun v state continue -> - # let (new_v, new_state) = match v with - # | A v -> let (v, state) = continue.a v state in (A' v, state) - # | B v -> let (v, state) = (fun x s -> (x,s)) v state in (B' v, state) - # | C v -> let (v, state) = (fun x s -> (x,s)) v state in (C' v, state) - # in - # (new_v, new_state) - # ); - - - - - - - # if not builtin: - # print (" let (v', state) = match v' with None -> visitor.%s v state continue_fold | Some v' -> (v', state) in" % (ct,)) - # else: - # print " let Some v' = v' in" - - # if not builtin: - # print (" let (%s, state) = match %s with None -> visitor.%s %s state continue_fold | Some v' -> (v', state) in" % (ff, ff, ft, f)) - # else: - # print " let Some v' = v' in" From 801efeed462e01f9328e1529314406aaabe3a0f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Suzanne=20Dup=C3=A9ron?= Date: Mon, 6 Jan 2020 13:04:38 +0100 Subject: [PATCH 4/4] tests for automatic fold generator + fold_config hook to update the state before a node has been transformed, without transforming it. --- src/stages/adt_generator/fold.ml | 12 +++++++++++ src/stages/adt_generator/generator.py | 3 +++ src/stages/adt_generator/use_a_fold.ml | 30 ++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/stages/adt_generator/fold.ml b/src/stages/adt_generator/fold.ml index 3d04a1e13..4e4c41357 100644 --- a/src/stages/adt_generator/fold.ml +++ b/src/stages/adt_generator/fold.ml @@ -36,19 +36,23 @@ type 'state continue_fold = type 'state fold_config = { root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + root_pre_state : root -> 'state -> 'state ; root_post_state : root -> root' -> 'state -> 'state ; root_A : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; root_B : int -> 'state -> ('state continue_fold) -> (int * 'state) ; root_C : string -> 'state -> ('state continue_fold) -> (string * 'state) ; a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + a_pre_state : a -> 'state -> 'state ; a_post_state : a -> a' -> 'state -> 'state ; a_a1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; a_a2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + ta1_pre_state : ta1 -> 'state -> 'state ; ta1_post_state : ta1 -> ta1' -> 'state -> 'state ; ta1_X : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; ta1_Y : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_pre_state : ta2 -> 'state -> 'state ; ta2_post_state : ta2 -> ta2' -> 'state -> 'state ; ta2_Z : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; ta2_W : unit -> 'state -> ('state continue_fold) -> (unit * 'state) ; @@ -74,6 +78,7 @@ let rec mk_continue_fold : type state . state fold_config -> state continue_fold and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.root_pre_state x state in let (new_x, state) = visitor.root x state continue_fold in let state = visitor.root_post_state x new_x state in (new_x, state) @@ -92,6 +97,7 @@ and fold_root_C : type state . state fold_config -> string -> state -> (string * and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.a_pre_state x state in let (new_x, state) = visitor.a x state continue_fold in let state = visitor.a_post_state x new_x state in (new_x, state) @@ -106,6 +112,7 @@ and fold_a_a2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.ta1_pre_state x state in let (new_x, state) = visitor.ta1 x state continue_fold in let state = visitor.ta1_post_state x new_x state in (new_x, state) @@ -120,6 +127,7 @@ and fold_ta1_Y : type state . state fold_config -> ta2 -> state -> (ta2' * state and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in + let state = visitor.ta2_pre_state x state in let (new_x, state) = visitor.ta2 x state continue_fold in let state = visitor.ta2_post_state x new_x state in (new_x, state) @@ -139,6 +147,7 @@ let no_op : 'a fold_config = { | B v -> let (v, state) = continue.root_B v state in (B' v, state) | C v -> let (v, state) = continue.root_C v state in (C' v, state) ); + root_pre_state = (fun v state -> ignore v; state) ; root_post_state = (fun v new_v state -> ignore (v, new_v); state) ; root_A = (fun v state continue -> continue.a v state ) ; root_B = (fun v state continue -> ignore continue; (v, state) ) ; @@ -150,6 +159,7 @@ let no_op : 'a fold_config = { let (a2', state) = continue.a_a2 a2 state in ({ a1'; a2'; }, state) ); + a_pre_state = (fun v state -> ignore v; state) ; a_post_state = (fun v new_v state -> ignore (v, new_v); state) ; a_a1 = (fun v state continue -> continue.ta1 v state ) ; a_a2 = (fun v state continue -> continue.ta2 v state ) ; @@ -158,6 +168,7 @@ let no_op : 'a fold_config = { | X v -> let (v, state) = continue.ta1_X v state in (X' v, state) | Y v -> let (v, state) = continue.ta1_Y v state in (Y' v, state) ); + ta1_pre_state = (fun v state -> ignore v; state) ; ta1_post_state = (fun v new_v state -> ignore (v, new_v); state) ; ta1_X = (fun v state continue -> continue.root v state ) ; ta1_Y = (fun v state continue -> continue.ta2 v state ) ; @@ -166,6 +177,7 @@ let no_op : 'a fold_config = { | Z v -> let (v, state) = continue.ta2_Z v state in (Z' v, state) | W v -> let (v, state) = continue.ta2_W v state in (W' v, state) ); + ta2_pre_state = (fun v state -> ignore v; state) ; ta2_post_state = (fun v new_v state -> ignore (v, new_v); state) ; ta2_Z = (fun v state continue -> continue.ta2 v state ) ; ta2_W = (fun v state continue -> ignore continue; (v, state) ) ; diff --git a/src/stages/adt_generator/generator.py b/src/stages/adt_generator/generator.py index e69a1fbf0..65fe21878 100644 --- a/src/stages/adt_generator/generator.py +++ b/src/stages/adt_generator/generator.py @@ -72,6 +72,7 @@ print(f"type 'state fold_config =") print(" {") for t in adts: print(f" {t.name} : {t.name} -> 'state -> ('state continue_fold) -> ({t.newName} * 'state) ;") + print(f" {t.name}_pre_state : {t.name} -> 'state -> 'state ;") print(f" {t.name}_post_state : {t.name} -> {t.newName} -> 'state -> 'state ;") for c in t.ctorsOrFields: print(f" {t.name}_{c.name} : {c.type_} -> 'state -> ('state continue_fold) -> ({c.newType} * 'state) ;") @@ -91,6 +92,7 @@ print("") for t in adts: print(f"and fold_{t.name} : type state . state fold_config -> {t.name} -> state -> ({t.newName} * state) = fun visitor x state ->") print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + print(f" let state = visitor.{t.name}_pre_state x state in") print(f" let (new_x, state) = visitor.{t.name} x state continue_fold in") print(f" let state = visitor.{t.name}_post_state x new_x state in") print(" (new_x, state)") @@ -120,6 +122,7 @@ for t in adts: print(f"{f.newName};", end=' ') print("}, state)") print(" );") + print(f" {t.name}_pre_state = (fun v state -> ignore v; state) ;") print(f" {t.name}_post_state = (fun v new_v state -> ignore (v, new_v); state) ;") for c in t.ctorsOrFields: print(f" {t.name}_{c.name} = (fun v state continue ->", end=' ') diff --git a/src/stages/adt_generator/use_a_fold.ml b/src/stages/adt_generator/use_a_fold.ml index 13f78e040..6a73f4782 100644 --- a/src/stages/adt_generator/use_a_fold.ml +++ b/src/stages/adt_generator/use_a_fold.ml @@ -1,7 +1,9 @@ open A open Fold -let _ = +(* TODO: how should we plug these into our test framework? *) + +let () = let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in let op = { no_op with @@ -15,8 +17,32 @@ let _ = } in let state = 0 in let (_, state) = fold_root op some_root state in - Printf.printf "trilili %d" state + if state != 2 then + failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state) + else + () + +let () = + let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in + let op = { no_op with a_pre_state = fun _the_a state -> state + 1 } in + let state = 0 in + let (_, state) = fold_root op some_root state in + if state != 2 then + failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state) + else + () + +let () = + let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in + let op = { no_op with a_post_state = fun _the_a _new_a state -> state + 1 } in + let state = 0 in + let (_, state) = fold_root op some_root state in + if state != 2 then + failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state) + else + () +(* Test that the same fold_config can be ascibed with different 'a type arguments *) let _noi : int fold_config = no_op (* (fun _ -> ()) *) let _nob : bool fold_config = no_op (* (fun _ -> ()) *)