Shell/P2p: do not read the tag twice...
This commit is contained in:
parent
9d67c1fea5
commit
6121c518d4
@ -129,103 +129,94 @@ module Make (P: PARAMS) = struct
|
||||
transmission (and needs not being replied). The [Unkown] packet is
|
||||
not a real kind of packet, it means that something indecypherable
|
||||
was transmitted. *)
|
||||
type hello = {
|
||||
gid: gid;
|
||||
port: int option;
|
||||
versions: version list;
|
||||
}
|
||||
|
||||
let hello_encoding =
|
||||
let open Data_encoding in
|
||||
conv
|
||||
(fun { gid; port; versions } -> (gid, port, versions))
|
||||
(fun (gid, port, versions) -> { gid; port; versions })
|
||||
(obj3
|
||||
(req "gid" (Fixed.string 16)) (* TODO: get rid of constant *)
|
||||
(opt "port" int16)
|
||||
(req "versions" (Variable.list version_encoding)))
|
||||
|
||||
type msg =
|
||||
| Connect of hello
|
||||
| Connect of {
|
||||
gid : string ;
|
||||
port : int option ;
|
||||
versions : version list ;
|
||||
}
|
||||
| Disconnect
|
||||
| Advertise of point list
|
||||
| Bootstrap
|
||||
| Advertise of point list
|
||||
| Message of P.msg
|
||||
|
||||
let msg_encoding =
|
||||
let open Data_encoding in
|
||||
union ~tag_size:`Uint8 begin [
|
||||
case ~tag:0x00 hello_encoding
|
||||
(function Connect hello -> Some hello | _ -> None)
|
||||
(fun hello -> Connect hello);
|
||||
case ~tag:0x01 null
|
||||
(function Disconnect -> Some () | _ -> None)
|
||||
(fun () -> Disconnect);
|
||||
case ~tag:0x04 (Variable.list point_encoding)
|
||||
(function Advertise points -> Some points | _ -> None)
|
||||
(fun points -> Advertise points);
|
||||
case ~tag:0x05 null
|
||||
(function Bootstrap -> Some () | _ -> None)
|
||||
(fun () -> Bootstrap);
|
||||
] @
|
||||
ListLabels.map P.encodings ~f:begin function Encoding { tag; encoding; wrap; unwrap } ->
|
||||
case ~tag encoding
|
||||
(function Message msg -> unwrap msg | _ -> None)
|
||||
(fun msg -> Message (wrap msg))
|
||||
end
|
||||
end
|
||||
union ~tag_size:`Uint16
|
||||
([ case ~tag:0x00
|
||||
(obj3
|
||||
(req "gid" (Fixed.string gid_length))
|
||||
(req "port" uint16)
|
||||
(req "versions" (Variable.list version_encoding)))
|
||||
(function
|
||||
| Connect { gid ; port ; versions } ->
|
||||
let port = match port with None -> 0 | Some port -> port in
|
||||
Some (gid, port, versions)
|
||||
| _ -> None)
|
||||
(fun (gid, port, versions) ->
|
||||
let port = if port = 0 then None else Some port in
|
||||
Connect { gid ; port ; versions });
|
||||
case ~tag:0x01 null
|
||||
(function Disconnect -> Some () | _ -> None)
|
||||
(fun () -> Disconnect);
|
||||
case ~tag:0x02 null
|
||||
(function Bootstrap -> Some () | _ -> None)
|
||||
(fun () -> Bootstrap);
|
||||
case ~tag:0x03 (Variable.list point_encoding)
|
||||
(function Advertise points -> Some points | _ -> None)
|
||||
(fun points -> Advertise points);
|
||||
] @
|
||||
ListLabels.map P.encodings
|
||||
~f:(function Encoding { tag ; encoding ; wrap ; unwrap } ->
|
||||
case ~tag encoding
|
||||
(function Message msg -> unwrap msg | _ -> None)
|
||||
(fun msg -> Message (wrap msg))))
|
||||
|
||||
let max_length = function
|
||||
| 0 -> Some 1024
|
||||
| 1 -> Some 0
|
||||
| 2 -> Some 0
|
||||
| 3 -> Some 0
|
||||
| 4 -> Some (1 + 1000 * 17) (* tag + 1000 * max (point size) *)
|
||||
| 5 -> Some 0
|
||||
| n -> ListLabels.fold_left P.encodings ~init:None ~f:begin fun a -> function
|
||||
Encoding { tag; max_length } -> if tag = n then max_length else a
|
||||
end
|
||||
|
||||
module BE = EndianBigstring.BigEndian
|
||||
let hdrlen = 2
|
||||
let maxlen = hdrlen + 2 lsl 16
|
||||
|
||||
(* read a message from a TCP socket *)
|
||||
let recv_msg fd buf =
|
||||
catch (fun () ->
|
||||
Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen ->
|
||||
if hdrlen <> 4 then begin
|
||||
debug "read: could not read enough bytes to determine message size, aborting";
|
||||
return None
|
||||
end
|
||||
else
|
||||
Lwt_bytes.read fd buf 0 4 >>= fun _hdrlen ->
|
||||
let len = Int32.to_int (BE.get_int32 buf 0) in
|
||||
if len < 0 || len > MBytes.length buf then begin
|
||||
debug "read: invalid message size %d" len;
|
||||
return None
|
||||
end
|
||||
else
|
||||
Lwt_utils.read_mbytes fd buf ~pos:4 ~len >|= fun () ->
|
||||
let tag = BE.get_uint8 buf 4 in
|
||||
let msg = MBytes.sub buf 4 len in
|
||||
match max_length tag with
|
||||
| Some len when MBytes.length msg > len -> None
|
||||
| _ -> Data_encoding.Binary.of_bytes msg_encoding msg)
|
||||
catch
|
||||
(fun () ->
|
||||
assert (MBytes.length buf >= 2 lsl 16) ;
|
||||
Lwt_utils.read_mbytes ~len:hdrlen fd buf >>= fun () ->
|
||||
let len = EndianBigstring.BigEndian.get_uint16 buf 0 in
|
||||
(* TODO timeout read ??? *)
|
||||
Lwt_utils.read_mbytes ~len fd buf >>= fun () ->
|
||||
(* TODO conditionnaly decrypt payload... ?? *)
|
||||
match Data_encoding.Binary.read msg_encoding buf 0 len with
|
||||
| None ->
|
||||
(* TODO track invalid message *)
|
||||
return Disconnect
|
||||
| Some (read, _) when read <> len ->
|
||||
(* TODO track invalid message *)
|
||||
return Disconnect
|
||||
| Some (_, msg) ->
|
||||
Lwt.return msg)
|
||||
(function
|
||||
| Unix.Unix_error (_err, _, _) -> return None
|
||||
| Unix.Unix_error _ -> return Disconnect
|
||||
| e -> fail e)
|
||||
|
||||
(* send a message over a TCP socket *)
|
||||
let send_msg fd buf packet =
|
||||
let send_msg fd buf msg =
|
||||
catch
|
||||
(fun () ->
|
||||
match Data_encoding.Binary.write msg_encoding packet buf 4 with
|
||||
match Data_encoding.Binary.write msg_encoding msg buf hdrlen with
|
||||
| None -> return_false
|
||||
| Some len ->
|
||||
BE.set_int32 buf 0 @@ Int32.of_int (len - 4);
|
||||
Lwt_utils.write_mbytes fd buf ~len >>= fun () ->
|
||||
return_true
|
||||
)
|
||||
(fun exn -> Lwt.fail exn)
|
||||
if len > maxlen then
|
||||
return_false
|
||||
else begin
|
||||
EndianBigstring.BigEndian.set_int16 buf 0 (len - hdrlen) ;
|
||||
(* TODO conditionnaly encrypt payload... ? *)
|
||||
(* TODO timeout write ??? *)
|
||||
Lwt_utils.write_mbytes ~len fd buf >>= fun () ->
|
||||
return true
|
||||
end)
|
||||
(function
|
||||
| Unix.Unix_error _ -> return_false
|
||||
| e -> fail e)
|
||||
|
||||
(* A peer handle, as a record-encoded object, abstract from the
|
||||
outside world. A hidden Lwt worker is associated to a peer at its
|
||||
@ -377,17 +368,18 @@ module Make (P: PARAMS) = struct
|
||||
let cancelation, cancel, on_cancel = canceler () in
|
||||
(* a cancelable reception *)
|
||||
let recv buf =
|
||||
pick [ (recv_msg socket buf >|= function Some p -> p | None -> Disconnect);
|
||||
pick [ recv_msg socket buf ;
|
||||
(cancelation () >>= fun () -> return Disconnect) ] in
|
||||
(* First step: send and receive credentials, makes no difference
|
||||
whether we're trying to connect to a peer or checking an incoming
|
||||
connection, both parties must first present themselves. *)
|
||||
let rec connect buf =
|
||||
send_msg socket buf (Connect { gid = my_gid ;
|
||||
port = config.incoming_port ;
|
||||
versions = P.supported_versions }) >>= fun _ ->
|
||||
send_msg socket buf
|
||||
(Connect { gid = my_gid ;
|
||||
port = config.incoming_port ;
|
||||
versions = P.supported_versions }) >>= fun _ ->
|
||||
pick [ (LU.sleep limits.peer_answer_timeout >>= fun () -> return Disconnect) ;
|
||||
recv buf ] >>= function
|
||||
recv_msg socket buf ] >>= function
|
||||
| Connect { gid; port = listening_port; versions } ->
|
||||
debug "(%a) connection requested from %a @ %a:%d"
|
||||
pp_gid my_gid pp_gid gid Ipaddr.pp_hum addr port ;
|
||||
@ -458,7 +450,7 @@ module Make (P: PARAMS) = struct
|
||||
(* Launch the worker *)
|
||||
receiver ()
|
||||
in
|
||||
let buf = MBytes.create 0x100_000 in
|
||||
let buf = MBytes.create maxlen in
|
||||
on_cancel (fun () ->
|
||||
send_msg socket buf Disconnect >>= fun _ ->
|
||||
LU.close socket >>= fun _ ->
|
||||
@ -542,8 +534,7 @@ module Make (P: PARAMS) = struct
|
||||
|
||||
(* A good random string so it is probably unique on the network *)
|
||||
let fresh_gid () =
|
||||
Bytes.to_string @@ Sodium.Random.Bytes.generate 16
|
||||
|
||||
Bytes.to_string @@ Sodium.Random.Bytes.generate gid_length
|
||||
|
||||
(* The (fixed size) broadcast frame. *)
|
||||
let discovery_message_encoding =
|
||||
|
Loading…
Reference in New Issue
Block a user