diff --git a/src/utils/error_monad.ml b/src/utils/error_monad.ml index 85dd780c4..17c70ec0a 100644 --- a/src/utils/error_monad.ml +++ b/src/utils/error_monad.ml @@ -299,9 +299,15 @@ module Make() = struct let fail_unless cond exn = if cond then return () else fail exn + let fail_when cond exn = + if cond then fail exn else return () + let unless cond f = if cond then return () else f () + let _when cond f = + if cond then f () else return () + let pp_print_error ppf errors = match errors with | [] -> @@ -339,6 +345,42 @@ let () = error_kinds := Error_kind { id; from_error ; category; encoding_case ; pp } :: !error_kinds +type error += Assert_error of string * string + +let () = + let id = "" in + let category = `Permanent in + let to_error (loc, msg) = Assert_error (loc, msg) in + let from_error = function + | Assert_error (loc, msg) -> Some (loc, msg) + | _ -> None in + let title = "Assertion error" in + let description = "An fatal assertion" in + let encoding_case = + let open Data_encoding in + case + (describe ~title ~description @@ + conv (fun (x, y) -> ((), x, y)) (fun ((), x, y) -> (x, y)) @@ + (obj3 + (req "kind" (constant "assertion")) + (req "location" string) + (req "error" string))) + from_error to_error in + let pp ppf (loc, msg) = + Format.fprintf ppf + "Assert failure (%s)%s" + loc + (if msg = "" then "." else ": " ^ msg) in + error_kinds := + Error_kind { id; from_error ; category; encoding_case ; pp } :: !error_kinds + +let _assert b loc fmt = + if b then + Format.ikfprintf (fun _ -> return ()) Format.str_formatter fmt + else + Format.kasprintf (fun msg -> fail (Assert_error (loc, msg))) fmt + + let protect ~on_error t = t >>= function | Ok res -> return res diff --git a/src/utils/error_monad_sig.ml b/src/utils/error_monad_sig.ml index 493d3f000..fa7680236 100644 --- a/src/utils/error_monad_sig.ml +++ b/src/utils/error_monad_sig.ml @@ -99,8 +99,15 @@ module type S = sig (** Erroneous return on failed assertion *) val fail_unless : bool -> error -> unit tzresult Lwt.t + val fail_when : bool -> error -> unit tzresult Lwt.t val unless : bool -> (unit -> unit tzresult Lwt.t) -> unit tzresult Lwt.t + val _when : bool -> (unit -> unit tzresult Lwt.t) -> unit tzresult Lwt.t + + (* Usage: [_assert cond __LOC__ "" ...] *) + val _assert : + bool -> string -> + ('a, Format.formatter, unit, unit tzresult Lwt.t) format4 -> 'a val protect : on_error: (error list -> 'a tzresult Lwt.t) ->