diff --git a/src/node/net/p2p_connection.ml b/src/node/net/p2p_connection.ml index b00c7619b..b4e439f9d 100644 --- a/src/node/net/p2p_connection.ml +++ b/src/node/net/p2p_connection.ml @@ -292,19 +292,18 @@ module Writer = struct canceler: Canceler.t ; conn: connection ; encoding: 'msg Data_encoding.t ; - messages: ('msg * unit tzresult Lwt.u option) Lwt_pipe.t ; + messages: (MBytes.t * unit tzresult Lwt.u option) Lwt_pipe.t ; mutable worker: unit Lwt.t ; } let encode_message st msg = - try return (Data_encoding.Binary.to_bytes st.encoding msg) - with _ -> fail Encoding_error + try ok (Data_encoding.Binary.to_bytes st.encoding msg) + with _ -> error Encoding_error let rec worker_loop st = Lwt_unix.yield () >>= fun () -> Lwt_utils.protect ~canceler:st.canceler begin fun () -> - Lwt_pipe.pop st.messages >>= fun (msg, wakener) -> - encode_message st msg >>=? fun buf -> + Lwt_pipe.pop st.messages >>= fun (buf, wakener) -> lwt_debug "writing %d bytes to %a" (MBytes.length buf) Connection_info.pp st.conn.info >>= fun () -> Crypto.write_chunk st.conn.fd st.conn.cryptobox_data buf >>= fun res -> @@ -326,10 +325,8 @@ module Writer = struct let run ?size conn encoding canceler = let compute_size = function - | msg, None -> - 10 * (Sys.word_size / 8) + Data_encoding.Binary.length encoding msg - | msg, Some _ -> - 18 * (Sys.word_size / 8) + Data_encoding.Binary.length encoding msg + | buf, None -> Sys.word_size + MBytes.length buf + | buf, Some _ -> 2 * Sys.word_size + MBytes.length buf in let size = map_option size ~f:(fun max -> max, compute_size) in let st = @@ -403,18 +400,28 @@ let catch_closed_pipe f = let write { writer } msg = catch_closed_pipe begin fun () -> - Lwt_pipe.push writer.messages (msg, None) >>= return + Lwt.return (Writer.encode_message writer msg) >>=? fun buf -> + Lwt_pipe.push writer.messages (buf, None) >>= return end let write_sync { writer } msg = catch_closed_pipe begin fun () -> let waiter, wakener = Lwt.wait () in - Lwt_pipe.push writer.messages (msg, Some wakener) >>= fun () -> + Lwt.return (Writer.encode_message writer msg) >>=? fun buf -> + Lwt_pipe.push writer.messages (buf, Some wakener) >>= fun () -> waiter end let write_now { writer } msg = - try Ok (Lwt_pipe.push_now writer.messages (msg, None)) + Writer.encode_message writer msg >>? fun buf -> + try Ok (Lwt_pipe.push_now writer.messages (buf, None)) with Lwt_pipe.Closed -> Error [P2p_io_scheduler.Connection_closed] +let raw_write_sync { writer } bytes = + catch_closed_pipe begin fun () -> + let waiter, wakener = Lwt.wait () in + Lwt_pipe.push writer.messages (bytes, Some wakener) >>= fun () -> + waiter + end + let is_readable { reader } = not (Lwt_pipe.is_empty reader.messages) let wait_readable { reader } = diff --git a/src/node/net/p2p_connection.mli b/src/node/net/p2p_connection.mli index 60aa6dbf5..4ec769413 100644 --- a/src/node/net/p2p_connection.mli +++ b/src/node/net/p2p_connection.mli @@ -112,3 +112,8 @@ val stat: 'msg t -> Stat.t [conn]. *) val close: ?wait:bool -> 'msg t -> unit Lwt.t + +(**/**) + +(** for testing only *) +val raw_write_sync: 'msg t -> MBytes.t -> unit tzresult Lwt.t diff --git a/src/node/net/p2p_connection_pool.ml b/src/node/net/p2p_connection_pool.ml index 61bab5608..ad5abb600 100644 --- a/src/node/net/p2p_connection_pool.ml +++ b/src/node/net/p2p_connection_pool.ml @@ -509,6 +509,9 @@ let write { conn } msg = let write_sync { conn } msg = P2p_connection.write_sync conn (Message msg) +let raw_write_sync { conn } buf = + P2p_connection.raw_write_sync conn buf + let write_now { conn } msg = P2p_connection.write_now conn (Message msg) diff --git a/src/node/net/p2p_connection_pool.mli b/src/node/net/p2p_connection_pool.mli index 43d13fafb..ce645156d 100644 --- a/src/node/net/p2p_connection_pool.mli +++ b/src/node/net/p2p_connection_pool.mli @@ -253,6 +253,10 @@ val write_sync: ('msg, 'meta) connection -> 'msg -> unit tzresult Lwt.t (** [write_sync conn msg] is [P2p_connection.write_sync conn' msg] where [conn'] is the internal [P2p_connection.t] inside [conn]. *) +(**/**) +val raw_write_sync: ('msg, 'meta) connection -> MBytes.t -> unit tzresult Lwt.t +(**/**) + val write_now: ('msg, 'meta) connection -> 'msg -> bool tzresult (** [write_now conn msg] is [P2p_connection.write_now conn' msg] where [conn'] is the internal [P2p_connection.t] inside [conn]. *)