From f458b6119cd23a1a4148d5941d3cd60b5b27c91b Mon Sep 17 00:00:00 2001 From: Vincent Bernardoff Date: Tue, 4 Sep 2018 00:23:51 +0900 Subject: [PATCH] Lwt_utils: connect: add a ?timeout argument This is to replace the default UNIX timeout that can be very long (30s or so). In the context of baking, it is not acceptable to wait for such a long time to connect to e.g. a signer daemon whenever there are multiple addresses available for load balancing. --- src/lib_stdlib_unix/lwt_utils_unix.ml | 30 ++++++++++++++------------ src/lib_stdlib_unix/lwt_utils_unix.mli | 13 ++++++++++- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/lib_stdlib_unix/lwt_utils_unix.ml b/src/lib_stdlib_unix/lwt_utils_unix.ml index 3c95ad6d9..0fd6447cd 100644 --- a/src/lib_stdlib_unix/lwt_utils_unix.ml +++ b/src/lib_stdlib_unix/lwt_utils_unix.ml @@ -288,7 +288,7 @@ module Socket = struct | Error (`Msg _) -> host | Ok ipaddr -> Ipaddr.to_string ipaddr - let connect = function + let connect ?(timeout=5.) = function | Unix path -> let addr = Lwt_unix.ADDR_UNIX path in let sock = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in @@ -300,22 +300,24 @@ module Socket = struct | [] -> failwith "could not resolve host '%s'" host | addrs -> - let rec try_connect = function + let rec try_connect acc = function | [] -> - failwith "could not connect to '%s'" host + Lwt.return + (Error (failure "could not connect to '%s'" host :: List.rev acc)) | { Unix.ai_family; ai_socktype; ai_protocol; ai_addr } :: addrs -> let sock = Lwt_unix.socket ai_family ai_socktype ai_protocol in - Lwt.catch - (fun () -> - Lwt_unix.connect sock ai_addr >>= fun () -> - return sock) - (fun exn -> - Format.printf "@{@{Unable to connect to %s@}@}@.\ - \ @[<h 0>%a@]@." 
- host Format.pp_print_text (Printexc.to_string exn) ; - Lwt_unix.close sock >>= fun () -> - try_connect addrs) in - try_connect addrs + protect ~on_error:begin fun e -> + Lwt_unix.close sock >>= fun () -> + Lwt.return (Error e) + end begin fun () -> + with_timeout (Lwt_unix.sleep timeout) (fun _c -> + Lwt_unix.connect sock ai_addr >>= fun () -> + return sock) + end >>= function + | Ok sock -> return sock + | Error e -> + try_connect (e @ acc) addrs in + try_connect [] addrs let bind ?(backlog = 10) = function | Unix path -> diff --git a/src/lib_stdlib_unix/lwt_utils_unix.mli b/src/lib_stdlib_unix/lwt_utils_unix.mli index 9058e8126..98a3467bf 100644 --- a/src/lib_stdlib_unix/lwt_utils_unix.mli +++ b/src/lib_stdlib_unix/lwt_utils_unix.mli @@ -78,7 +78,18 @@ module Socket : sig | Unix of string | Tcp of string * string * Unix.getaddrinfo_option list - val connect: addr -> Lwt_unix.file_descr tzresult Lwt.t + val connect: + ?timeout:float -> addr -> Lwt_unix.file_descr tzresult Lwt.t + (** [connect ?timeout addr] tries connecting to [addr] and returns + the resulting socket file descriptor on success. When using TCP, + [Unix.getaddrinfo] is used to resolve the hostname and service + (port). The different socket addresses returned by + [Unix.getaddrinfo] are tried sequentially, and the [?timeout] + argument (default: 5s) governs how long it waits to get a + connection. If a connection is not obtained in less than + [?timeout], the connection is canceled and the next socket + address (if it exists) is tried. *) + val bind: ?backlog:int -> addr -> Lwt_unix.file_descr list tzresult Lwt.t