From 07595c9e1fdf0cb0a4e7f2ea0530b13bc822ee32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9goire=20Henry?= Date: Wed, 20 Feb 2019 14:44:38 +0100 Subject: [PATCH] P2p: improve cancelation Pending connections were not easily interuptible. --- src/lib_p2p/p2p_io_scheduler.ml | 19 +++++---- src/lib_p2p/p2p_io_scheduler.mli | 6 ++- src/lib_p2p/p2p_pool.ml | 2 + src/lib_p2p/p2p_socket.ml | 60 ++++++++++++++--------------- src/lib_p2p/p2p_socket.mli | 2 + src/lib_p2p/test/test_p2p_socket.ml | 30 +++++++++------ 6 files changed, 68 insertions(+), 51 deletions(-) diff --git a/src/lib_p2p/p2p_io_scheduler.ml b/src/lib_p2p/p2p_io_scheduler.ml index 6dc8e50a4..0d7d3fc9a 100644 --- a/src/lib_p2p/p2p_io_scheduler.ml +++ b/src/lib_p2p/p2p_io_scheduler.ml @@ -390,10 +390,11 @@ let register st conn = conn end -let write { write_queue } msg = - Lwt.catch - (fun () -> Lwt_pipe.push write_queue msg >>= return) - (fun _ -> fail P2p_errors.Connection_closed) +let write ?canceler { write_queue } msg = + trace P2p_errors.Connection_closed @@ + protect ?canceler begin fun () -> + Lwt_pipe.push write_queue msg >>= return + end let write_now { write_queue } msg = Lwt_pipe.push_now write_queue msg let read_from conn ?pos ?len buf msg = @@ -426,7 +427,7 @@ let read_now conn ?pos ?len buf = (Lwt_pipe.pop_now conn.read_queue) with Lwt_pipe.Closed -> Some (Error [P2p_errors.Connection_closed]) -let read conn ?pos ?len buf = +let read ?canceler conn ?pos ?len buf = match conn.partial_read with | Some msg -> conn.partial_read <- None ; @@ -434,11 +435,13 @@ let read conn ?pos ?len buf = | None -> Lwt.catch (fun () -> - Lwt_pipe.pop conn.read_queue >|= fun msg -> + protect ?canceler begin fun () -> + Lwt_pipe.pop conn.read_queue + end >|= fun msg -> read_from conn ?pos ?len buf msg) (fun _ -> fail P2p_errors.Connection_closed) -let read_full conn ?pos ?len buf = +let read_full ?canceler conn ?pos ?len buf = let maxlen = MBytes.length buf in let pos = Option.unopt ~default:0 pos in let len = Option.unopt ~default:(maxlen - pos) len in @@ -448,7 +451,7 @@ let read_full conn ?pos ?len buf = if len = 0 then return_unit else - read conn ~pos ~len buf >>=? fun read_len -> + read ?canceler conn ~pos ~len buf >>=? fun read_len -> loop (pos + read_len) (len - read_len) in loop pos len diff --git a/src/lib_p2p/p2p_io_scheduler.mli b/src/lib_p2p/p2p_io_scheduler.mli index 7235056f9..a259d5f3f 100644 --- a/src/lib_p2p/p2p_io_scheduler.mli +++ b/src/lib_p2p/p2p_io_scheduler.mli @@ -61,7 +61,9 @@ val create: val register: t -> P2p_fd.t -> connection (** [register sched fd] is a [connection] managed by [sched]. *) -val write: connection -> MBytes.t -> unit tzresult Lwt.t +val write: + ?canceler:Lwt_canceler.t -> + connection -> MBytes.t -> unit tzresult Lwt.t (** [write conn msg] returns [Ok ()] when [msg] has been added to [conn]'s write queue, or fail with an error. *) @@ -76,11 +78,13 @@ val read_now: [buf] starting at [pos]. *) val read: + ?canceler:Lwt_canceler.t -> connection -> ?pos:int -> ?len:int -> MBytes.t -> int tzresult Lwt.t (** Like [read_now], but waits till [conn] read queue has at least one element instead of failing. *) val read_full: + ?canceler:Lwt_canceler.t -> connection -> ?pos:int -> ?len:int -> MBytes.t -> unit tzresult Lwt.t (** Like [read], but blits exactly [len] bytes in [buf]. *) diff --git a/src/lib_p2p/p2p_pool.ml b/src/lib_p2p/p2p_pool.ml index 19ffbec8d..eb3587d89 100644 --- a/src/lib_p2p/p2p_pool.ml +++ b/src/lib_p2p/p2p_pool.ml @@ -783,6 +783,7 @@ and raw_authenticate pool ?point_info canceler fd point = (if incoming then " incoming" else "") >>= fun () -> protect ~canceler begin fun () -> P2p_socket.authenticate + ~canceler ~proof_of_work_target:pool.config.proof_of_work_target ~incoming fd point ?listening_port:pool.config.listening_port @@ -885,6 +886,7 @@ and raw_authenticate pool ?point_info canceler fd point = ?incoming_message_queue_size:pool.config.incoming_message_queue_size ?outgoing_message_queue_size:pool.config.outgoing_message_queue_size ?binary_chunks_size:pool.config.binary_chunks_size + ~canceler auth_fd pool.encoding >>=? fun conn -> lwt_debug "authenticate: %a -> Connected %a" P2p_point.Id.pp point diff --git a/src/lib_p2p/p2p_socket.ml b/src/lib_p2p/p2p_socket.ml index 2b7911d14..11b5c997f 100644 --- a/src/lib_p2p/p2p_socket.ml +++ b/src/lib_p2p/p2p_socket.ml @@ -56,7 +56,7 @@ module Crypto = struct input and output. *) let () = assert (Crypto_box.boxzerobytes >= header_length) - let write_chunk fd cryptobox_data msg = + let write_chunk ?canceler fd cryptobox_data msg = let msglen = MBytes.length msg in fail_unless (msglen <= max_content_length) P2p_errors.Invalid_message_size >>=? fun () -> @@ -71,15 +71,15 @@ module Crypto = struct let header_pos = Crypto_box.boxzerobytes - header_length in MBytes.set_int16 buf header_pos encrypted_length ; let payload = MBytes.sub buf header_pos (buf_length - header_pos) in - P2p_io_scheduler.write fd payload + P2p_io_scheduler.write ?canceler fd payload - let read_chunk fd cryptobox_data = + let read_chunk ?canceler fd cryptobox_data = let header_buf = MBytes.create header_length in - P2p_io_scheduler.read_full ~len:header_length fd header_buf >>=? fun () -> + P2p_io_scheduler.read_full ?canceler ~len:header_length fd header_buf >>=? fun () -> let encrypted_length = MBytes.get_uint16 header_buf 0 in let buf_length = encrypted_length + Crypto_box.boxzerobytes in let buf = MBytes.make buf_length '\x00' in - P2p_io_scheduler.read_full + P2p_io_scheduler.read_full ?canceler ~pos:Crypto_box.boxzerobytes ~len:encrypted_length fd buf >>=? fun () -> let remote_nonce = cryptobox_data.remote_nonce in cryptobox_data.remote_nonce <- Crypto_box.increment_nonce remote_nonce ; @@ -140,7 +140,7 @@ module Connection_message = struct (req "message_nonce" Crypto_box.nonce_encoding) (req "versions" (Variable.list P2p_version.encoding))) - let write fd message = + let write ~canceler fd message = let encoded_message_len = Data_encoding.Binary.length encoding message in fail_unless @@ -155,20 +155,20 @@ module Connection_message = struct | Some last -> fail_unless (last = len) P2p_errors.Encoding_error >>=? fun () -> MBytes.set_int16 buf 0 encoded_message_len ; - P2p_io_scheduler.write fd buf >>=? fun () -> + P2p_io_scheduler.write ~canceler fd buf >>=? fun () -> (* We return the raw message as it is used later to compute the nonces *) return buf - let read fd = + let read ~canceler fd = let header_buf = MBytes.create Crypto.header_length in - P2p_io_scheduler.read_full + P2p_io_scheduler.read_full ~canceler ~len:Crypto.header_length fd header_buf >>=? fun () -> let len = MBytes.get_uint16 header_buf 0 in let pos = Crypto.header_length in let buf = MBytes.create (pos + len) in MBytes.set_int16 buf 0 len ; - P2p_io_scheduler.read_full ~len ~pos fd buf >>=? fun () -> + P2p_io_scheduler.read_full ~canceler ~len ~pos fd buf >>=? fun () -> match Data_encoding.Binary.read encoding buf pos len with | None -> fail P2p_errors.Decoding_error @@ -188,7 +188,7 @@ type 'meta metadata_config = { module Metadata = struct - let write metadata_config cryptobox_data fd message = + let write ~canceler metadata_config cryptobox_data fd message = let encoded_message_len = Data_encoding.Binary.length metadata_config.conn_meta_encoding message in let buf = MBytes.create encoded_message_len in @@ -201,10 +201,10 @@ module Metadata = struct | Some last -> fail_unless (last = encoded_message_len) P2p_errors.Encoding_error >>=? fun () -> - Crypto.write_chunk cryptobox_data fd buf + Crypto.write_chunk ~canceler cryptobox_data fd buf - let read metadata_config fd cryptobox_data = - Crypto.read_chunk fd cryptobox_data >>=? fun buf -> + let read ~canceler metadata_config fd cryptobox_data = + Crypto.read_chunk ~canceler fd cryptobox_data >>=? fun buf -> let length = MBytes.length buf in let encoding = metadata_config.conn_meta_encoding in match @@ -248,7 +248,7 @@ module Ack = struct nack_case (Tag 255) ; ] - let write fd cryptobox_data message = + let write ?canceler fd cryptobox_data message = let encoded_message_len = Data_encoding.Binary.length encoding message in let buf = MBytes.create encoded_message_len in @@ -258,10 +258,10 @@ module Ack = struct | Some last -> fail_unless (last = encoded_message_len) P2p_errors.Encoding_error >>=? fun () -> - Crypto.write_chunk fd cryptobox_data buf + Crypto.write_chunk ?canceler fd cryptobox_data buf - let read fd cryptobox_data = - Crypto.read_chunk fd cryptobox_data >>=? fun buf -> + let read ?canceler fd cryptobox_data = + Crypto.read_chunk ?canceler fd cryptobox_data >>=? fun buf -> let length = MBytes.length buf in match Data_encoding.Binary.read encoding buf 0 length with | None -> @@ -289,18 +289,19 @@ let kick { fd ; cryptobox_data ; _ } = whether we're trying to connect to a peer or checking an incoming connection, both parties must first introduce themselves. *) let authenticate + ~canceler ~proof_of_work_target ~incoming fd (remote_addr, remote_socket_port as point) ?listening_port identity supported_versions metadata_config = let local_nonce_seed = Crypto_box.random_nonce () in lwt_debug "Sending authenfication to %a" P2p_point.Id.pp point >>= fun () -> - Connection_message.write fd + Connection_message.write ~canceler fd { public_key = identity.P2p_identity.public_key ; proof_of_work_stamp = identity.proof_of_work_stamp ; message_nonce = local_nonce_seed ; port = listening_port ; versions = supported_versions } >>=? fun sent_msg -> - Connection_message.read fd >>=? fun (msg, recv_msg) -> + Connection_message.read ~canceler fd >>=? fun (msg, recv_msg) -> let remote_listening_port = if incoming then msg.port else Some remote_socket_port in let id_point = remote_addr, remote_listening_port in @@ -318,8 +319,8 @@ let authenticate Crypto_box.generate_nonces ~incoming ~sent_msg ~recv_msg in let cryptobox_data = { Crypto.channel_key ; local_nonce ; remote_nonce } in let local_metadata = metadata_config.conn_meta_value remote_peer_id in - Metadata.write metadata_config fd cryptobox_data local_metadata >>=? fun () -> - Metadata.read metadata_config fd cryptobox_data >>=? fun remote_metadata -> + Metadata.write ~canceler metadata_config fd cryptobox_data local_metadata >>=? fun () -> + Metadata.read ~canceler metadata_config fd cryptobox_data >>=? fun remote_metadata -> let info = { P2p_connection.Info.peer_id = remote_peer_id ; versions = msg.versions ; incoming ; @@ -351,9 +352,8 @@ module Reader = struct lwt_debug "[read_message] incremental decoding error" >>= fun () -> return_none | Await decode_next_buf -> - protect ~canceler:st.canceler begin fun () -> - Crypto.read_chunk st.conn.fd st.conn.cryptobox_data - end >>=? fun buf -> + Crypto.read_chunk ~canceler:st.canceler + st.conn.fd st.conn.cryptobox_data >>=? fun buf -> lwt_debug "reading %d bytes from %a" (MBytes.length buf) P2p_peer.Id.pp st.conn.info.peer_id >>= fun () -> @@ -432,9 +432,8 @@ module Writer = struct let rec loop = function | [] -> return_unit | buf :: l -> - protect ~canceler:st.canceler begin fun () -> - Crypto.write_chunk st.conn.fd st.conn.cryptobox_data buf - end >>=? fun () -> + Crypto.write_chunk ~canceler:st.canceler + st.conn.fd st.conn.cryptobox_data buf >>=? fun () -> lwt_debug "writing %d bytes to %a" (MBytes.length buf) P2p_peer.Id.pp st.conn.info.peer_id >>= fun () -> loop l in @@ -561,11 +560,12 @@ let private_node { conn } = conn.info.private_node let accept ?incoming_message_queue_size ?outgoing_message_queue_size ?binary_chunks_size + ~canceler conn encoding = protect begin fun () -> - Ack.write conn.fd conn.cryptobox_data Ack >>=? fun () -> - Ack.read conn.fd conn.cryptobox_data + Ack.write ~canceler conn.fd conn.cryptobox_data Ack >>=? fun () -> + Ack.read ~canceler conn.fd conn.cryptobox_data end ~on_error:begin fun err -> P2p_io_scheduler.close conn.fd >>= fun _ -> match err with diff --git a/src/lib_p2p/p2p_socket.mli b/src/lib_p2p/p2p_socket.mli index 6dd9a2d17..d0e4a9e78 100644 --- a/src/lib_p2p/p2p_socket.mli +++ b/src/lib_p2p/p2p_socket.mli @@ -62,6 +62,7 @@ val private_node: ('msg, 'meta) t -> bool (** {1 Low-level functions (do not use directly)} *) val authenticate: + canceler:Lwt_canceler.t -> proof_of_work_target:Crypto_box.target -> incoming:bool -> P2p_io_scheduler.connection -> P2p_point.Id.t -> @@ -84,6 +85,7 @@ val accept: ?incoming_message_queue_size:int -> ?outgoing_message_queue_size:int -> ?binary_chunks_size: int -> + canceler:Lwt_canceler.t -> 'meta authenticated_connection -> 'msg Data_encoding.t -> ('msg, 'meta) t tzresult Lwt.t (** (Low-level) (Cancelable) Accepts a remote peer given an diff --git a/src/lib_p2p/test/test_p2p_socket.ml b/src/lib_p2p/test/test_p2p_socket.ml index 6fb8fa9a2..54bcd06be 100644 --- a/src/lib_p2p/test/test_p2p_socket.ml +++ b/src/lib_p2p/test/test_p2p_socket.ml @@ -27,6 +27,8 @@ include Logging.Make (struct let name = "test.p2p.connection" end) let addr = ref Ipaddr.V6.localhost +let canceler = Lwt_canceler.create () (* unused *) + let proof_of_work_target = Crypto_box.make_target 16. let id1 = P2p_identity.generate proof_of_work_target let id2 = P2p_identity.generate proof_of_work_target @@ -117,6 +119,7 @@ let raw_accept sched main_socket = let accept sched main_socket = raw_accept sched main_socket >>= fun (fd, point) -> P2p_socket.authenticate + ~canceler ~proof_of_work_target ~incoming:true fd point id1 versions conn_meta_config @@ -132,6 +135,7 @@ let raw_connect sched addr port = let connect sched addr port id = raw_connect sched addr port >>= fun fd -> P2p_socket.authenticate + ~canceler ~proof_of_work_target ~incoming:false fd (addr, port) id versions conn_meta_config >>=? fun (info, auth_fd) -> @@ -197,7 +201,7 @@ module Kick = struct let client _ch sched addr port = connect sched addr port id2 >>=? fun auth_fd -> - P2p_socket.accept auth_fd encoding >>= fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>= fun conn -> _assert (is_rejected conn) __LOC__ "" >>=? fun () -> return_unit @@ -211,7 +215,7 @@ module Kicked = struct let server _ch sched socket = accept sched socket >>=? fun (_info, auth_fd) -> - P2p_socket.accept auth_fd encoding >>= fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>= fun conn -> _assert (Kick.is_rejected conn) __LOC__ "" >>=? fun () -> return_unit @@ -233,7 +237,7 @@ module Simple_message = struct let server ch sched socket = accept sched socket >>=? fun (_info, auth_fd) -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> P2p_socket.write_sync conn simple_msg >>=? fun () -> P2p_socket.read conn >>=? fun (_msg_size, msg) -> _assert (MBytes.compare simple_msg2 msg = 0) __LOC__ "" >>=? fun () -> @@ -243,7 +247,7 @@ module Simple_message = struct let client ch sched addr port = connect sched addr port id2 >>=? fun auth_fd -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> P2p_socket.write_sync conn simple_msg2 >>=? fun () -> P2p_socket.read conn >>=? fun (_msg_size, msg) -> _assert (MBytes.compare simple_msg msg = 0) __LOC__ "" >>=? fun () -> @@ -265,6 +269,7 @@ module Chunked_message = struct let server ch sched socket = accept sched socket >>=? fun (_info, auth_fd) -> P2p_socket.accept + ~canceler ~binary_chunks_size:21 auth_fd encoding >>=? fun conn -> P2p_socket.write_sync conn simple_msg >>=? fun () -> P2p_socket.read conn >>=? fun (_msg_size, msg) -> @@ -276,6 +281,7 @@ module Chunked_message = struct let client ch sched addr port = connect sched addr port id2 >>=? fun auth_fd -> P2p_socket.accept + ~canceler ~binary_chunks_size:21 auth_fd encoding >>=? fun conn -> P2p_socket.write_sync conn simple_msg2 >>=? fun () -> P2p_socket.read conn >>=? fun (_msg_size, msg) -> @@ -297,7 +303,7 @@ module Oversized_message = struct let server ch sched socket = accept sched socket >>=? fun (_info, auth_fd) -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> P2p_socket.write_sync conn simple_msg >>=? fun () -> P2p_socket.read conn >>=? fun (_msg_size, msg) -> _assert (MBytes.compare simple_msg2 msg = 0) __LOC__ "" >>=? fun () -> @@ -307,7 +313,7 @@ module Oversized_message = struct let client ch sched addr port = connect sched addr port id2 >>=? fun auth_fd -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> P2p_socket.write_sync conn simple_msg2 >>=? fun () -> P2p_socket.read conn >>=? fun (_msg_size, msg) -> _assert (MBytes.compare simple_msg msg = 0) __LOC__ "" >>=? fun () -> @@ -327,14 +333,14 @@ module Close_on_read = struct let server ch sched socket = accept sched socket >>=? fun (_info, auth_fd) -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> sync ch >>=? fun () -> P2p_socket.close conn >>= fun _stat -> return_unit let client ch sched addr port = connect sched addr port id2 >>=? fun auth_fd -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> sync ch >>=? fun () -> P2p_socket.read conn >>= fun err -> _assert (is_connection_closed err) __LOC__ "" >>=? fun () -> @@ -353,14 +359,14 @@ module Close_on_write = struct let server ch sched socket = accept sched socket >>=? fun (_info, auth_fd) -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> P2p_socket.close conn >>= fun _stat -> sync ch >>=? fun ()-> return_unit let client ch sched addr port = connect sched addr port id2 >>=? fun auth_fd -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> sync ch >>=? fun ()-> Lwt_unix.sleep 0.1 >>= fun () -> P2p_socket.write_sync conn simple_msg >>= fun err -> @@ -390,7 +396,7 @@ module Garbled_data = struct let server _ch sched socket = accept sched socket >>=? fun (_info, auth_fd) -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> P2p_socket.raw_write_sync conn garbled_msg >>=? fun () -> P2p_socket.read conn >>= fun err -> _assert (is_connection_closed err) __LOC__ "" >>=? fun () -> @@ -399,7 +405,7 @@ module Garbled_data = struct let client _ch sched addr port = connect sched addr port id2 >>=? fun auth_fd -> - P2p_socket.accept auth_fd encoding >>=? fun conn -> + P2p_socket.accept ~canceler auth_fd encoding >>=? fun conn -> P2p_socket.read conn >>= fun err -> _assert (is_decoding_error err) __LOC__ "" >>=? fun () -> P2p_socket.close conn >>= fun _stat ->