diff --git a/src/lib_stdlib_unix/lwt_utils_unix.ml b/src/lib_stdlib_unix/lwt_utils_unix.ml index 3c95ad6d9..0fd6447cd 100644 --- a/src/lib_stdlib_unix/lwt_utils_unix.ml +++ b/src/lib_stdlib_unix/lwt_utils_unix.ml @@ -288,7 +288,7 @@ module Socket = struct | Error (`Msg _) -> host | Ok ipaddr -> Ipaddr.to_string ipaddr - let connect = function + let connect ?(timeout=5.) = function | Unix path -> let addr = Lwt_unix.ADDR_UNIX path in let sock = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in @@ -300,22 +300,24 @@ module Socket = struct | [] -> failwith "could not resolve host '%s'" host | addrs -> - let rec try_connect = function + let rec try_connect acc = function | [] -> - failwith "could not connect to '%s'" host + Lwt.return + (Error (failure "could not connect to '%s'" host :: List.rev acc)) | { Unix.ai_family; ai_socktype; ai_protocol; ai_addr } :: addrs -> let sock = Lwt_unix.socket ai_family ai_socktype ai_protocol in - Lwt.catch - (fun () -> - Lwt_unix.connect sock ai_addr >>= fun () -> - return sock) - (fun exn -> - Format.printf "@{@{Unable to connect to %s@}@}@.\ - \ @[<h 0>%a@]@." - host Format.pp_print_text (Printexc.to_string exn) ; - Lwt_unix.close sock >>= fun () -> - try_connect addrs) in - try_connect addrs + protect ~on_error:begin fun e -> + Lwt_unix.close sock >>= fun () -> + Lwt.return (Error e) + end begin fun () -> + with_timeout (Lwt_unix.sleep timeout) (fun _c -> + Lwt_unix.connect sock ai_addr >>= fun () -> + return sock) + end >>= function + | Ok sock -> return sock + | Error e -> + try_connect (e @ acc) addrs in + try_connect [] addrs let bind ?(backlog = 10) = function | Unix path -> diff --git a/src/lib_stdlib_unix/lwt_utils_unix.mli b/src/lib_stdlib_unix/lwt_utils_unix.mli index 9058e8126..98a3467bf 100644 --- a/src/lib_stdlib_unix/lwt_utils_unix.mli +++ b/src/lib_stdlib_unix/lwt_utils_unix.mli @@ -78,7 +78,18 @@ module Socket : sig | Unix of string | Tcp of string * string * Unix.getaddrinfo_option list - val connect: addr -> Lwt_unix.file_descr tzresult Lwt.t + val connect: + ?timeout:float -> addr -> Lwt_unix.file_descr tzresult Lwt.t + (** [connect ?timeout addr] tries connecting to [addr] and returns + the resulting socket file descriptor on success. When using TCP, + [Unix.getaddrinfo] is used to resolve the hostname and service + (port). The different socket addresses returned by + [Unix.getaddrinfo] are tried sequentially, and the [?timeout] + argument (default: 5s) governs how long it waits to get a + connection. If a connection is not obtained in less than + [?timeout], the connection is canceled and and the next socket + address (if it exists) is tried. *) + val bind: ?backlog:int -> addr -> Lwt_unix.file_descr list tzresult Lwt.t