diff --git a/src/node/net/p2p.ml b/src/node/net/p2p.ml index 84c71d639..982018524 100644 --- a/src/node/net/p2p.ml +++ b/src/node/net/p2p.ml @@ -196,19 +196,8 @@ module Make (P: PARAMS) = struct module BE = EndianBigstring.BigEndian - (** Read a message from a file descriptor and returns (tag, msg) *) - let read fd buf = - let rec read_into_exactly ?(pos=0) ?len descr buf = - let len = match len with None -> MBytes.length buf | Some l -> l in - let rec inner pos len = - if len = 0 then - Lwt.return_unit - else - Lwt_bytes.read descr buf pos len >>= fun nb_read -> - inner (pos + nb_read) (len - nb_read) - in - inner pos len - in + (* read a message from a TCP socket *) + let recv_msg fd buf = catch (fun () -> Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen -> if hdrlen <> 4 then begin @@ -223,42 +212,26 @@ module Make (P: PARAMS) = struct return None end else - read_into_exactly fd buf ~pos:4 ~len >|= fun () -> + Lwt_utils.read_mbytes fd buf ~pos:4 ~len >|= fun () -> let tag = BE.get_uint8 buf 4 in - Some (tag, MBytes.sub buf 4 len)) + let msg = MBytes.sub buf 4 len in + match max_length tag with + | Some len when MBytes.length msg > len -> None + | _ -> Data_encoding.Binary.of_bytes msg_encoding msg) (function | Unix.Unix_error (_err, _, _) -> return None | e -> fail e) - (** Write a message to file descriptor. *) - let write ?(pos=0) ?len descr buf = - let len = match len with None -> MBytes.length buf | Some l -> l in - catch - (fun () -> - Lwt_bytes.write descr buf pos len >>= fun _nb_written -> - return true) - (function - | Unix.Unix_error _ -> return false - | e -> fail e) - - (* read a message from a TCP socket *) - let recv_msg fd buf = - read fd buf >|= function - | None -> None - | Some (tag, msg) -> - match max_length tag with - | Some len when MBytes.length msg > len -> None - | _ -> Data_encoding.Binary.of_bytes msg_encoding msg - (* send a message over a TCP socket *) let send_msg fd buf packet = catch (fun () -> match Data_encoding.Binary.write msg_encoding packet buf 4 with - | None -> return false + | None -> return_false | Some len -> BE.set_int32 buf 0 @@ Int32.of_int (len - 4); - write fd buf ~len + Lwt_utils.write_mbytes fd buf ~len >>= fun () -> + return_true ) (fun exn -> Lwt.fail exn)