diff --git a/src/node/net/p2p.ml b/src/node/net/p2p.ml index 559eb6b9a..2993d7dfe 100644 --- a/src/node/net/p2p.ml +++ b/src/node/net/p2p.ml @@ -176,7 +176,7 @@ module Real = struct P2p_maintenance.shutdown net.maintenance >>= fun () -> Lwt_utils.may ~f:P2p_discovery.shutdown net.discoverer >>= fun () -> P2p_connection_pool.destroy net.pool >>= fun () -> - P2p_io_scheduler.shutdown net.io_sched + P2p_io_scheduler.shutdown ~timeout:3.0 net.io_sched let connections { pool } () = P2p_connection_pool.fold_connections pool @@ -200,18 +200,24 @@ module Real = struct let rec recv_any net () = let pipes = P2p_connection_pool.fold_connections - net.pool ~init:[] ~f:begin fun _gid conn acc -> + net.pool ~init:[] + ~f:begin fun _gid conn acc -> (P2p_connection_pool.is_readable conn >>= function - | Ok () -> Lwt.return conn + | Ok () -> Lwt.return (Some conn) | Error _ -> Lwt_utils.never_ending) :: acc end in - Lwt.pick pipes >>= fun conn -> - P2p_connection_pool.read conn >>= function - | Ok msg -> - Lwt.return (conn, msg) - | Error _ -> - Lwt_unix.yield () >>= fun () -> - recv_any net () + Lwt.pick ( + ( P2p_connection_pool.Events.new_connection net.pool >>= fun () -> + Lwt.return_none ):: + pipes) >>= function + | None -> recv_any net () + | Some conn -> + P2p_connection_pool.read conn >>= function + | Ok msg -> + Lwt.return (conn, msg) + | Error _ -> + Lwt_unix.yield () >>= fun () -> + recv_any net () let send _net c m = P2p_connection_pool.write c m >>= function diff --git a/src/node/net/p2p_connection_pool.ml b/src/node/net/p2p_connection_pool.ml index 23d2eb09e..ef0852f74 100644 --- a/src/node/net/p2p_connection_pool.ml +++ b/src/node/net/p2p_connection_pool.ml @@ -7,8 +7,6 @@ (* *) (**************************************************************************) -(* TODO check version negotiation *) - (* TODO Test cancelation of a (pending) connection *) (* TODO do not recompute list_known_points at each requests... but @@ -177,6 +175,7 @@ and events = { too_few_connections : unit Lwt_condition.t ; too_many_connections : unit Lwt_condition.t ; new_point : unit Lwt_condition.t ; + new_connection : unit Lwt_condition.t ; } and ('msg, 'meta) connection = { @@ -245,7 +244,7 @@ let list_known_points pool _gid () = let active_connections pool = Gid.Table.length pool.connected_gids -let create_connection pool conn id_point pi gi = +let create_connection pool conn id_point pi gi _version = let gid = Gid_info.gid gi in let canceler = Canceler.create () in let size = @@ -268,6 +267,7 @@ let create_connection pool conn id_point pi gi = end ; Gid_info.State.set_running gi id_point conn ; Gid.Table.add pool.connected_gids gid gi ; + Lwt_condition.broadcast pool.events.new_connection () ; Canceler.on_cancel canceler begin fun () -> lwt_debug "Disconnect: %a (%a)" Gid.pp gid Id_point.pp id_point >>= fun () -> @@ -338,6 +338,9 @@ let authenticate pool ?pi canceler fd point = | None, None -> None | Some _ as pi, _ | _, (Some _ as pi) -> pi in let gi = register_peer pool info.gid in + let acceptable_versions = + Version.common info.versions pool.message_config.versions + in let acceptable_point = unopt_map connection_pi ~default:(not pool.config.closed_network) @@ -359,47 +362,49 @@ let authenticate pool ?pi canceler fd point = | Disconnected -> true in if incoming then Point.Table.remove pool.incoming point ; - if not acceptable_gid || not acceptable_point then begin - lwt_debug "authenticate: %a -> kick %a point: %B gid: %B" - Point.pp point - Connection_info.pp info - acceptable_point acceptable_gid >>= fun () -> - P2p_connection.kick auth_fd >>= fun () -> - if not incoming then begin - iter_option ~f:Point_info.State.set_disconnected pi ; - (* FIXME Gid_info.State.set_disconnected ~requested:true gi ; *) - end ; - fail (Rejected info.gid) - end else begin - iter_option connection_pi - ~f:(fun pi -> Point_info.State.set_accepted pi info.gid canceler) ; - Gid_info.State.set_accepted gi info.id_point canceler ; - lwt_debug "authenticate: %a -> accept %a" - Point.pp point - Connection_info.pp info >>= fun () -> - Lwt_utils.protect ~canceler begin fun () -> - P2p_connection.accept - ?incoming_message_queue_size:pool.config.incoming_message_queue_size - ?outgoing_message_queue_size:pool.config.outgoing_message_queue_size - auth_fd pool.encoding >>= fun conn -> - lwt_debug "authenticate: %a -> Connected %a" + match acceptable_versions with + | Some version when acceptable_gid && acceptable_point -> begin + iter_option connection_pi + ~f:(fun pi -> Point_info.State.set_accepted pi info.gid canceler) ; + Gid_info.State.set_accepted gi info.id_point canceler ; + lwt_debug "authenticate: %a -> accept %a" Point.pp point Connection_info.pp info >>= fun () -> - Lwt.return conn - end ~on_error: begin fun err -> - lwt_debug "authenticate: %a -> rejected %a" + Lwt_utils.protect ~canceler begin fun () -> + P2p_connection.accept + ?incoming_message_queue_size:pool.config.incoming_message_queue_size + ?outgoing_message_queue_size:pool.config.outgoing_message_queue_size + auth_fd pool.encoding >>= fun conn -> + lwt_debug "authenticate: %a -> Connected %a" + Point.pp point + Connection_info.pp info >>= fun () -> + Lwt.return conn + end ~on_error: begin fun err -> + lwt_debug "authenticate: %a -> rejected %a" + Point.pp point + Connection_info.pp info >>= fun () -> + iter_option connection_pi ~f:Point_info.State.set_disconnected; + Gid_info.State.set_disconnected gi ; + Lwt.return (Error err) + end >>=? fun conn -> + let id_point = + match info.id_point, map_option Point_info.point pi with + | (addr, _), Some (_, port) -> addr, Some port + | id_point, None -> id_point in + return (create_connection pool conn id_point connection_pi gi version) + end + | _ -> begin + lwt_debug "authenticate: %a -> kick %a point: %B gid: %B" Point.pp point - Connection_info.pp info >>= fun () -> - iter_option connection_pi ~f:Point_info.State.set_disconnected; - Gid_info.State.set_disconnected gi ; - Lwt.return (Error err) - end >>=? fun conn -> - let id_point = - match info.id_point, map_option Point_info.point pi with - | (addr, _), Some (_, port) -> addr, Some port - | id_point, None -> id_point in - return (create_connection pool conn id_point connection_pi gi) - end + Connection_info.pp info + acceptable_point acceptable_gid >>= fun () -> + P2p_connection.kick auth_fd >>= fun () -> + if not incoming then begin + iter_option ~f:Point_info.State.set_disconnected pi ; + (* FIXME Gid_info.State.set_disconnected ~requested:true gi ; *) + end ; + fail (Rejected info.gid) + end type error += Pending_connection type error += Connected @@ -437,6 +442,7 @@ let raw_connect canceler pool point = end ~on_error: begin fun err -> lwt_debug "connect: %a -> disconnect" Point.pp point >>= fun () -> Point_info.State.set_disconnected pi ; + Lwt_utils.safe_close fd >>= fun () -> match err with | [Exn (Unix.Unix_error (Unix.ECONNREFUSED, _, _))] -> fail Connection_refused @@ -604,6 +610,8 @@ module Events = struct Lwt_condition.wait pool.events.too_many_connections let new_point pool = Lwt_condition.wait pool.events.new_point + let new_connection pool = + Lwt_condition.wait pool.events.new_connection end @@ -623,6 +631,7 @@ let create config meta_config message_config io_sched = too_few_connections = Lwt_condition.create () ; too_many_connections = Lwt_condition.create () ; new_point = Lwt_condition.create () ; + new_connection = Lwt_condition.create () ; } in let pool = { config ; meta_config ; message_config ; diff --git a/src/node/net/p2p_connection_pool.mli b/src/node/net/p2p_connection_pool.mli index 27ef938c0..7eaf08445 100644 --- a/src/node/net/p2p_connection_pool.mli +++ b/src/node/net/p2p_connection_pool.mli @@ -131,6 +131,7 @@ module Events : sig val too_few_connections: ('msg, 'meta) pool -> unit Lwt.t val too_many_connections: ('msg, 'meta) pool -> unit Lwt.t val new_point: ('msg, 'meta) pool -> unit Lwt.t + val new_connection: ('msg, 'meta) pool -> unit Lwt.t end (** {1 Connections management} *) diff --git a/src/node/net/p2p_io_scheduler.ml b/src/node/net/p2p_io_scheduler.ml index 7b3fcd164..216cc4181 100644 --- a/src/node/net/p2p_io_scheduler.ml +++ b/src/node/net/p2p_io_scheduler.ml @@ -29,6 +29,8 @@ module type IO = sig val close: out_param -> error list -> unit Lwt.t end +type error += Connection_closed + module Scheduler(IO : IO) = struct type t = { @@ -111,8 +113,9 @@ module Scheduler(IO : IO) = struct match msg with | Error [Lwt_utils.Canceled] -> worker_loop st - | Error ([Exn (Lwt_pipe.Closed | - Unix.Unix_error (EBADF, _, _))] as err) -> + | Error ([Connection_closed | + Exn ( Lwt_pipe.Closed | + Unix.Unix_error (EBADF, _, _) )] as err) -> cancel conn err >>= fun () -> worker_loop st | Error err -> @@ -125,7 +128,8 @@ module Scheduler(IO : IO) = struct | Ok () | Error [Lwt_utils.Canceled] -> return () - | Error ([Exn (Unix.Unix_error (EBADF, _, _) | + | Error ([Connection_closed | + Exn (Unix.Unix_error (EBADF, _, _) | Lwt_pipe.Closed)] as err) -> cancel conn err >>= fun () -> return () @@ -196,8 +200,6 @@ module Scheduler(IO : IO) = struct end -type error += Connection_closed - module ReadScheduler = Scheduler(struct let name = "io_scheduler(read)" type in_param = Lwt_unix.file_descr * int @@ -239,6 +241,7 @@ module WriteScheduler = Scheduler(struct (fun () -> Lwt_utils.write_mbytes fd buf >>= return) (function + | Unix.Unix_error(Unix.ECONNRESET, _, _) | Unix.Unix_error(Unix.EPIPE, _, _) | Lwt.Canceled | End_of_file -> @@ -440,21 +443,32 @@ let stat { read_conn ; write_conn} = and ws = Moving_average.stat write_conn.counter in convert ~rs ~ws -let close conn = +let close ?timeout conn = Inttbl.remove conn.sched.connected conn.id ; Lwt_pipe.close conn.write_queue ; - Canceler.cancelation conn.canceler >>= fun () -> + begin + match timeout with + | None -> + return (Canceler.cancelation conn.canceler) + | Some timeout -> + Lwt_utils.with_timeout + ~canceler:conn.canceler timeout begin fun canceler -> + return (Canceler.cancelation canceler) + end + end >>=? fun _ -> conn.write_conn.current_push >>= fun res -> Lwt.return res let iter_connection { connected } f = Inttbl.iter f connected -let shutdown st = +let shutdown ?timeout st = + lwt_log_info "--> shutdown" >>= fun () -> st.closed <- true ; ReadScheduler.shutdown st.read_scheduler >>= fun () -> - WriteScheduler.shutdown st.write_scheduler >>= fun () -> Inttbl.fold - (fun _gid conn acc -> close conn >>= fun _ -> acc) + (fun _gid conn acc -> close ?timeout conn >>= fun _ -> acc) st.connected - Lwt.return_unit + Lwt.return_unit >>= fun () -> + WriteScheduler.shutdown st.write_scheduler >>= fun () -> + Lwt.return_unit diff --git a/src/node/net/p2p_io_scheduler.mli b/src/node/net/p2p_io_scheduler.mli index f5641ff35..363ead969 100644 --- a/src/node/net/p2p_io_scheduler.mli +++ b/src/node/net/p2p_io_scheduler.mli @@ -83,11 +83,11 @@ val iter_connection: t -> (int -> connection -> unit) -> unit (** [iter_connection sched f] applies [f] on each connection managed by [sched]. *) -val close: connection -> unit tzresult Lwt.t +val close: ?timeout:float -> connection -> unit tzresult Lwt.t (** [close conn] cancels [conn] and returns after any pending data has been sent. *) -val shutdown: t -> unit Lwt.t +val shutdown: ?timeout:float -> t -> unit Lwt.t (** [shutdown sched] returns after all connections managed by [sched] have been closed and [sched]'s inner worker has successfully canceled. *) diff --git a/src/node/net/p2p_maintenance.ml b/src/node/net/p2p_maintenance.ml index 1ff34c9f0..0def0ede7 100644 --- a/src/node/net/p2p_maintenance.ml +++ b/src/node/net/p2p_maintenance.ml @@ -57,7 +57,7 @@ let connectable st start_time expected = | Disconnected -> begin match Point_info.last_miss pi with | Some last when Time.(start_time < last) - && not (Point_info.greylisted ~now pi) -> () + || Point_info.greylisted ~now pi -> () | last -> Bounded_point_info.insert (last, point) acc end