Shell: improve Lwt_utils
This commit is contained in:
parent
0027d70158
commit
6b3e002285
@ -581,7 +581,7 @@ module Make (P: PARAMS) = struct
|
||||
"(%a) connection handler for %a:%d"
|
||||
pp_gid my_gid Ipaddr.pp_hum addr port in
|
||||
ignore (Lwt_utils.worker worker_name
|
||||
~safe:true ~run:(fun () -> connect buf) ~cancel) ;
|
||||
~run:(fun () -> connect buf) ~cancel) ;
|
||||
(* return the canceler *)
|
||||
cancel
|
||||
|
||||
|
@ -7,7 +7,6 @@
|
||||
(* *)
|
||||
(**************************************************************************)
|
||||
|
||||
|
||||
exception Exit
|
||||
|
||||
let termination_thread, exit_wakener = Lwt.wait ()
|
||||
@ -18,6 +17,12 @@ let () =
|
||||
(function
|
||||
| Exit -> ()
|
||||
| exn ->
|
||||
Printf.eprintf "Uncaught (asynchronous) exception: %S\n%s\n%!"
|
||||
(Printexc.to_string exn) (Printexc.get_backtrace ());
|
||||
Format.eprintf
|
||||
"@[Uncaught (asynchronous) exception (%d):@ %a@]"
|
||||
(Unix.getpid ())
|
||||
Error_monad.pp_exn exn ;
|
||||
let backtrace = Printexc.get_backtrace () in
|
||||
if String.length backtrace <> 0 then
|
||||
Format.eprintf "\n%s" backtrace ;
|
||||
Format.eprintf "@." ;
|
||||
Lwt.wakeup exit_wakener 1)
|
||||
|
@ -12,7 +12,7 @@ module LC = Lwt_condition
|
||||
open Lwt.Infix
|
||||
open Logging.Core
|
||||
|
||||
let may f = function
|
||||
let may ~f = function
|
||||
| None -> Lwt.return_unit
|
||||
| Some x -> f x
|
||||
|
||||
@ -39,10 +39,13 @@ let canceler ()
|
||||
else begin
|
||||
canceling := true ;
|
||||
LC.broadcast cancelation () ;
|
||||
!cancel_hook () >>= fun () ->
|
||||
canceled := true ;
|
||||
LC.broadcast cancelation_complete () ;
|
||||
Lwt.return ()
|
||||
Lwt.finalize
|
||||
!cancel_hook
|
||||
(fun () ->
|
||||
canceled := true ;
|
||||
LC.broadcast cancelation_complete () ;
|
||||
Lwt.return ()) >>= fun () ->
|
||||
Lwt.return_unit
|
||||
end
|
||||
in
|
||||
let on_cancel cb =
|
||||
@ -55,6 +58,53 @@ let canceler ()
|
||||
in
|
||||
cancelation, cancel, on_cancel
|
||||
|
||||
module Canceler = struct
|
||||
|
||||
type t = {
|
||||
cancelation: unit Lwt_condition.t ;
|
||||
cancelation_complete: unit Lwt_condition.t ;
|
||||
mutable cancel_hook: unit -> unit Lwt.t ;
|
||||
mutable canceling: bool ;
|
||||
mutable canceled: bool ;
|
||||
}
|
||||
|
||||
let create () =
|
||||
let cancelation = LC.create () in
|
||||
let cancelation_complete = LC.create () in
|
||||
{ cancelation ; cancelation_complete ;
|
||||
cancel_hook = (fun () -> Lwt.return ()) ;
|
||||
canceling = false ;
|
||||
canceled = false ;
|
||||
}
|
||||
|
||||
let cancel st =
|
||||
if st.canceled then
|
||||
Lwt.return ()
|
||||
else if st.canceling then
|
||||
LC.wait st.cancelation_complete
|
||||
else begin
|
||||
st.canceling <- true ;
|
||||
LC.broadcast st.cancelation () ;
|
||||
Lwt.finalize
|
||||
st.cancel_hook
|
||||
(fun () ->
|
||||
st.canceled <- true ;
|
||||
LC.broadcast st.cancelation_complete () ;
|
||||
Lwt.return ())
|
||||
end
|
||||
|
||||
let on_cancel st cb =
|
||||
let hook = st.cancel_hook in
|
||||
st.cancel_hook <- (fun () -> hook () >>= cb)
|
||||
|
||||
let cancelation st =
|
||||
if st.canceling then Lwt.return ()
|
||||
else LC.wait st.cancelation
|
||||
|
||||
let canceled st = st.canceling
|
||||
|
||||
end
|
||||
|
||||
type trigger =
|
||||
| Absent
|
||||
| Present
|
||||
@ -114,12 +164,11 @@ let queue () : ('a -> unit) * (unit -> 'a list Lwt.t) =
|
||||
queue, wait
|
||||
|
||||
(* A worker launcher, takes a cancel callback to call upon *)
|
||||
let worker ?(safe=false) name ~run ~cancel =
|
||||
let worker name ~run ~cancel =
|
||||
let stop = LC.create () in
|
||||
let fail e =
|
||||
log_error "%s worker failed with %s" name (Printexc.to_string e) ;
|
||||
cancel () >>= fun () ->
|
||||
if safe then Lwt.return_unit else Lwt.fail e
|
||||
cancel ()
|
||||
in
|
||||
let waiter = LC.wait stop in
|
||||
log_info "%s worker started" name ;
|
||||
@ -263,6 +312,17 @@ let write_mbytes ?(pos=0) ?len descr buf =
|
||||
| nb_written -> inner (pos + nb_written) (len - nb_written) in
|
||||
inner pos len
|
||||
|
||||
let write_bytes ?(pos=0) ?len descr buf =
|
||||
let len = match len with None -> Bytes.length buf - pos | Some l -> l in
|
||||
let rec inner pos len =
|
||||
if len = 0 then
|
||||
Lwt.return_unit
|
||||
else
|
||||
Lwt_unix.write descr buf pos len >>= function
|
||||
| 0 -> Lwt.fail End_of_file (* other endpoint cleanly closed its connection *)
|
||||
| nb_written -> inner (pos + nb_written) (len - nb_written) in
|
||||
inner pos len
|
||||
|
||||
let (>>=) = Lwt.bind
|
||||
|
||||
let remove_dir dir =
|
||||
@ -297,3 +357,49 @@ let create_file ?(perm = 0o644) name content =
|
||||
Lwt_unix.openfile name Unix.([O_TRUNC; O_CREAT; O_WRONLY]) perm >>= fun fd ->
|
||||
Lwt_unix.write_string fd content 0 (String.length content) >>= fun _ ->
|
||||
Lwt_unix.close fd
|
||||
|
||||
let safe_close fd =
|
||||
Lwt.catch
|
||||
(fun () -> Lwt_unix.close fd)
|
||||
(fun _ -> Lwt.return_unit)
|
||||
|
||||
open Error_monad
|
||||
|
||||
type error += Canceled
|
||||
|
||||
let protect ?on_error ?canceler t =
|
||||
let cancelation =
|
||||
match canceler with
|
||||
| None -> never_ending
|
||||
| Some canceler ->
|
||||
( Canceler.cancelation canceler >>= fun () ->
|
||||
fail Canceled ) in
|
||||
let res =
|
||||
Lwt.pick [ cancelation ;
|
||||
Lwt.catch t (fun exn -> fail (Exn exn)) ] in
|
||||
res >>= function
|
||||
| Ok _ -> res
|
||||
| Error err ->
|
||||
let canceled =
|
||||
Utils.unopt_map canceler ~default:false ~f:Canceler.canceled in
|
||||
let err = if canceled then [Canceled] else err in
|
||||
match on_error with
|
||||
| None -> Lwt.return (Error err)
|
||||
| Some on_error -> on_error err
|
||||
|
||||
type error += Timeout
|
||||
|
||||
let with_timeout ?(canceler = Canceler.create ()) timeout f =
|
||||
let t = Lwt_unix.sleep timeout in
|
||||
Lwt.choose [
|
||||
(t >|= fun () -> None) ;
|
||||
(f canceler >|= fun x -> Some x)
|
||||
] >>= function
|
||||
| Some x when Lwt.state t = Lwt.Sleep ->
|
||||
Lwt.cancel t ;
|
||||
Lwt.return x
|
||||
| _ ->
|
||||
Canceler.cancel canceler >>= fun () ->
|
||||
fail Timeout
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@
|
||||
(* *)
|
||||
(**************************************************************************)
|
||||
|
||||
val may : ('a -> unit Lwt.t) -> 'a option -> unit Lwt.t
|
||||
val may: f:('a -> unit Lwt.t) -> 'a option -> unit Lwt.t
|
||||
|
||||
val never_ending: 'a Lwt.t
|
||||
|
||||
@ -16,8 +16,18 @@ val canceler : unit ->
|
||||
(unit -> unit Lwt.t) *
|
||||
((unit -> unit Lwt.t) -> unit)
|
||||
|
||||
module Canceler : sig
|
||||
|
||||
type t
|
||||
val create : unit -> t
|
||||
val cancel : t -> unit Lwt.t
|
||||
val cancelation : t -> unit Lwt.t
|
||||
val on_cancel : t -> (unit -> unit Lwt.t) -> unit
|
||||
val canceled : t -> bool
|
||||
|
||||
end
|
||||
|
||||
val worker:
|
||||
?safe:bool ->
|
||||
string ->
|
||||
run:(unit -> unit Lwt.t) ->
|
||||
cancel:(unit -> unit Lwt.t) ->
|
||||
@ -33,9 +43,27 @@ val read_bytes:
|
||||
val read_mbytes:
|
||||
?pos:int -> ?len:int -> Lwt_unix.file_descr -> MBytes.t -> unit Lwt.t
|
||||
|
||||
val write_bytes:
|
||||
?pos:int -> ?len:int -> Lwt_unix.file_descr -> bytes -> unit Lwt.t
|
||||
val write_mbytes:
|
||||
?pos:int -> ?len:int -> Lwt_unix.file_descr -> MBytes.t -> unit Lwt.t
|
||||
|
||||
val remove_dir: string -> unit Lwt.t
|
||||
val create_dir: ?perm:int -> string -> unit Lwt.t
|
||||
val create_file: ?perm:int -> string -> string -> unit Lwt.t
|
||||
|
||||
val safe_close: Lwt_unix.file_descr -> unit Lwt.t
|
||||
|
||||
open Error_monad
|
||||
|
||||
type error += Canceled
|
||||
val protect :
|
||||
?on_error:(error list -> 'a tzresult Lwt.t) ->
|
||||
?canceler:Canceler.t ->
|
||||
(unit -> 'a tzresult Lwt.t) -> 'a tzresult Lwt.t
|
||||
|
||||
type error += Timeout
|
||||
val with_timeout:
|
||||
?canceler:Canceler.t ->
|
||||
float -> (Canceler.t -> 'a tzresult Lwt.t) -> 'a tzresult Lwt.t
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user