diff --git a/src/node/net/p2p.ml b/src/node/net/p2p.ml index 0f73bc0f9..da7cd29c5 100644 --- a/src/node/net/p2p.ml +++ b/src/node/net/p2p.ml @@ -129,103 +129,94 @@ module Make (P: PARAMS) = struct transmission (and needs not being replied). The [Unkown] packet is not a real kind of packet, it means that something indecypherable was transmitted. *) - type hello = { - gid: gid; - port: int option; - versions: version list; - } - - let hello_encoding = - let open Data_encoding in - conv - (fun { gid; port; versions } -> (gid, port, versions)) - (fun (gid, port, versions) -> { gid; port; versions }) - (obj3 - (req "gid" (Fixed.string 16)) (* TODO: get rid of constant *) - (opt "port" int16) - (req "versions" (Variable.list version_encoding))) - type msg = - | Connect of hello + | Connect of { + gid : string ; + port : int option ; + versions : version list ; + } | Disconnect - | Advertise of point list | Bootstrap + | Advertise of point list | Message of P.msg let msg_encoding = let open Data_encoding in - union ~tag_size:`Uint8 begin [ - case ~tag:0x00 hello_encoding - (function Connect hello -> Some hello | _ -> None) - (fun hello -> Connect hello); - case ~tag:0x01 null - (function Disconnect -> Some () | _ -> None) - (fun () -> Disconnect); - case ~tag:0x04 (Variable.list point_encoding) - (function Advertise points -> Some points | _ -> None) - (fun points -> Advertise points); - case ~tag:0x05 null - (function Bootstrap -> Some () | _ -> None) - (fun () -> Bootstrap); - ] @ - ListLabels.map P.encodings ~f:begin function Encoding { tag; encoding; wrap; unwrap } -> - case ~tag encoding - (function Message msg -> unwrap msg | _ -> None) - (fun msg -> Message (wrap msg)) - end - end + union ~tag_size:`Uint16 + ([ case ~tag:0x00 + (obj3 + (req "gid" (Fixed.string gid_length)) + (req "port" uint16) + (req "versions" (Variable.list version_encoding))) + (function + | Connect { gid ; port ; versions } -> + let port = match port with None -> 0 | Some port -> port in + Some (gid, port, versions) + | _ -> None) + (fun (gid, port, versions) -> + let port = if port = 0 then None else Some port in + Connect { gid ; port ; versions }); + case ~tag:0x01 null + (function Disconnect -> Some () | _ -> None) + (fun () -> Disconnect); + case ~tag:0x02 null + (function Bootstrap -> Some () | _ -> None) + (fun () -> Bootstrap); + case ~tag:0x03 (Variable.list point_encoding) + (function Advertise points -> Some points | _ -> None) + (fun points -> Advertise points); + ] @ + ListLabels.map P.encodings + ~f:(function Encoding { tag ; encoding ; wrap ; unwrap } -> + case ~tag encoding + (function Message msg -> unwrap msg | _ -> None) + (fun msg -> Message (wrap msg)))) - let max_length = function - | 0 -> Some 1024 - | 1 -> Some 0 - | 2 -> Some 0 - | 3 -> Some 0 - | 4 -> Some (1 + 1000 * 17) (* tag + 1000 * max (point size) *) - | 5 -> Some 0 - | n -> ListLabels.fold_left P.encodings ~init:None ~f:begin fun a -> function - Encoding { tag; max_length } -> if tag = n then max_length else a - end - - module BE = EndianBigstring.BigEndian + let hdrlen = 2 + let maxlen = hdrlen + 2 lsl 16 (* read a message from a TCP socket *) let recv_msg fd buf = - catch (fun () -> - Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen -> - if hdrlen <> 4 then begin - debug "read: could not read enough bytes to determine message size, aborting"; - return None - end - else - Lwt_bytes.read fd buf 0 4 >>= fun _hdrlen -> - let len = Int32.to_int (BE.get_int32 buf 0) in - if len < 0 || len > MBytes.length buf then begin - debug "read: invalid message size %d" len; - return None - end - else - Lwt_utils.read_mbytes fd buf ~pos:4 ~len >|= fun () -> - let tag = BE.get_uint8 buf 4 in - let msg = MBytes.sub buf 4 len in - match max_length tag with - | Some len when MBytes.length msg > len -> None - | _ -> Data_encoding.Binary.of_bytes msg_encoding msg) + catch + (fun () -> + assert (MBytes.length buf >= 2 lsl 16) ; + Lwt_utils.read_mbytes ~len:hdrlen fd buf >>= fun () -> + let len = EndianBigstring.BigEndian.get_uint16 buf 0 in + (* TODO timeout read ??? *) + Lwt_utils.read_mbytes ~len fd buf >>= fun () -> + (* TODO conditionnaly decrypt payload... ?? *) + match Data_encoding.Binary.read msg_encoding buf 0 len with + | None -> + (* TODO track invalid message *) + return Disconnect + | Some (read, _) when read <> len -> + (* TODO track invalid message *) + return Disconnect + | Some (_, msg) -> + Lwt.return msg) (function - | Unix.Unix_error (_err, _, _) -> return None + | Unix.Unix_error _ -> return Disconnect | e -> fail e) (* send a message over a TCP socket *) - let send_msg fd buf packet = + let send_msg fd buf msg = catch (fun () -> - match Data_encoding.Binary.write msg_encoding packet buf 4 with + match Data_encoding.Binary.write msg_encoding msg buf hdrlen with | None -> return_false | Some len -> - BE.set_int32 buf 0 @@ Int32.of_int (len - 4); - Lwt_utils.write_mbytes fd buf ~len >>= fun () -> - return_true - ) - (fun exn -> Lwt.fail exn) + if len > maxlen then + return_false + else begin + EndianBigstring.BigEndian.set_int16 buf 0 (len - hdrlen) ; + (* TODO conditionnaly encrypt payload... ? *) + (* TODO timeout write ??? *) + Lwt_utils.write_mbytes ~len fd buf >>= fun () -> + return true + end) + (function + | Unix.Unix_error _ -> return_false + | e -> fail e) (* A peer handle, as a record-encoded object, abstract from the outside world. A hidden Lwt worker is associated to a peer at its @@ -377,17 +368,18 @@ module Make (P: PARAMS) = struct let cancelation, cancel, on_cancel = canceler () in (* a cancelable reception *) let recv buf = - pick [ (recv_msg socket buf >|= function Some p -> p | None -> Disconnect); + pick [ recv_msg socket buf ; (cancelation () >>= fun () -> return Disconnect) ] in (* First step: send and receive credentials, makes no difference whether we're trying to connect to a peer or checking an incoming connection, both parties must first present themselves. *) let rec connect buf = - send_msg socket buf (Connect { gid = my_gid ; - port = config.incoming_port ; - versions = P.supported_versions }) >>= fun _ -> + send_msg socket buf + (Connect { gid = my_gid ; + port = config.incoming_port ; + versions = P.supported_versions }) >>= fun _ -> pick [ (LU.sleep limits.peer_answer_timeout >>= fun () -> return Disconnect) ; - recv buf ] >>= function + recv_msg socket buf ] >>= function | Connect { gid; port = listening_port; versions } -> debug "(%a) connection requested from %a @ %a:%d" pp_gid my_gid pp_gid gid Ipaddr.pp_hum addr port ; @@ -458,7 +450,7 @@ module Make (P: PARAMS) = struct (* Launch the worker *) receiver () in - let buf = MBytes.create 0x100_000 in + let buf = MBytes.create maxlen in on_cancel (fun () -> send_msg socket buf Disconnect >>= fun _ -> LU.close socket >>= fun _ -> @@ -542,8 +534,7 @@ module Make (P: PARAMS) = struct (* A good random string so it is probably unique on the network *) let fresh_gid () = - Bytes.to_string @@ Sodium.Random.Bytes.generate 16 - + Bytes.to_string @@ Sodium.Random.Bytes.generate gid_length (* The (fixed size) broadcast frame. *) let discovery_message_encoding =