Shell: introduce Error_monad._assert.

This commit is contained in:
Grégoire Henry 2017-04-10 00:09:01 +02:00
parent 18e2edf6f4
commit 4537c8780e
2 changed files with 49 additions and 0 deletions

View File

@ -299,9 +299,15 @@ module Make() = struct
let fail_unless cond exn =
if cond then return () else fail exn
let fail_when cond exn =
if cond then fail exn else return ()
let unless cond f =
if cond then return () else f ()
let _when cond f =
if cond then f () else return ()
let pp_print_error ppf errors =
match errors with
| [] ->
@ -339,6 +345,42 @@ let () =
error_kinds :=
Error_kind { id; from_error ; category; encoding_case ; pp } :: !error_kinds
type error += Assert_error of string * string
let () =
let id = "" in
let category = `Permanent in
let to_error (loc, msg) = Assert_error (loc, msg) in
let from_error = function
| Assert_error (loc, msg) -> Some (loc, msg)
| _ -> None in
let title = "Assertion error" in
let description = "An fatal assertion" in
let encoding_case =
let open Data_encoding in
case
(describe ~title ~description @@
conv (fun (x, y) -> ((), x, y)) (fun ((), x, y) -> (x, y)) @@
(obj3
(req "kind" (constant "assertion"))
(req "location" string)
(req "error" string)))
from_error to_error in
let pp ppf (loc, msg) =
Format.fprintf ppf
"Assert failure (%s)%s"
loc
(if msg = "" then "." else ": " ^ msg) in
error_kinds :=
Error_kind { id; from_error ; category; encoding_case ; pp } :: !error_kinds
let _assert b loc fmt =
if b then
Format.ikfprintf (fun _ -> return ()) Format.str_formatter fmt
else
Format.kasprintf (fun msg -> fail (Assert_error (loc, msg))) fmt
let protect ~on_error t =
t >>= function
| Ok res -> return res

View File

@ -99,8 +99,15 @@ module type S = sig
(** Erroneous return on failed assertion *)
val fail_unless : bool -> error -> unit tzresult Lwt.t
val fail_when : bool -> error -> unit tzresult Lwt.t
val unless : bool -> (unit -> unit tzresult Lwt.t) -> unit tzresult Lwt.t
val _when : bool -> (unit -> unit tzresult Lwt.t) -> unit tzresult Lwt.t
(* Usage: [_assert cond __LOC__ "<fmt>" ...] *)
val _assert :
bool -> string ->
('a, Format.formatter, unit, unit tzresult Lwt.t) format4 -> 'a
val protect :
on_error: (error list -> 'a tzresult Lwt.t) ->