Shell/P2p: do not read the tag twice...
This commit is contained in:
parent
9d67c1fea5
commit
6121c518d4
@ -129,103 +129,94 @@ module Make (P: PARAMS) = struct
|
|||||||
transmission (and needs not being replied). The [Unkown] packet is
|
transmission (and needs not being replied). The [Unkown] packet is
|
||||||
not a real kind of packet, it means that something indecypherable
|
not a real kind of packet, it means that something indecypherable
|
||||||
was transmitted. *)
|
was transmitted. *)
|
||||||
type hello = {
|
|
||||||
gid: gid;
|
|
||||||
port: int option;
|
|
||||||
versions: version list;
|
|
||||||
}
|
|
||||||
|
|
||||||
let hello_encoding =
|
|
||||||
let open Data_encoding in
|
|
||||||
conv
|
|
||||||
(fun { gid; port; versions } -> (gid, port, versions))
|
|
||||||
(fun (gid, port, versions) -> { gid; port; versions })
|
|
||||||
(obj3
|
|
||||||
(req "gid" (Fixed.string 16)) (* TODO: get rid of constant *)
|
|
||||||
(opt "port" int16)
|
|
||||||
(req "versions" (Variable.list version_encoding)))
|
|
||||||
|
|
||||||
type msg =
|
type msg =
|
||||||
| Connect of hello
|
| Connect of {
|
||||||
|
gid : string ;
|
||||||
|
port : int option ;
|
||||||
|
versions : version list ;
|
||||||
|
}
|
||||||
| Disconnect
|
| Disconnect
|
||||||
| Advertise of point list
|
|
||||||
| Bootstrap
|
| Bootstrap
|
||||||
|
| Advertise of point list
|
||||||
| Message of P.msg
|
| Message of P.msg
|
||||||
|
|
||||||
let msg_encoding =
|
let msg_encoding =
|
||||||
let open Data_encoding in
|
let open Data_encoding in
|
||||||
union ~tag_size:`Uint8 begin [
|
union ~tag_size:`Uint16
|
||||||
case ~tag:0x00 hello_encoding
|
([ case ~tag:0x00
|
||||||
(function Connect hello -> Some hello | _ -> None)
|
(obj3
|
||||||
(fun hello -> Connect hello);
|
(req "gid" (Fixed.string gid_length))
|
||||||
|
(req "port" uint16)
|
||||||
|
(req "versions" (Variable.list version_encoding)))
|
||||||
|
(function
|
||||||
|
| Connect { gid ; port ; versions } ->
|
||||||
|
let port = match port with None -> 0 | Some port -> port in
|
||||||
|
Some (gid, port, versions)
|
||||||
|
| _ -> None)
|
||||||
|
(fun (gid, port, versions) ->
|
||||||
|
let port = if port = 0 then None else Some port in
|
||||||
|
Connect { gid ; port ; versions });
|
||||||
case ~tag:0x01 null
|
case ~tag:0x01 null
|
||||||
(function Disconnect -> Some () | _ -> None)
|
(function Disconnect -> Some () | _ -> None)
|
||||||
(fun () -> Disconnect);
|
(fun () -> Disconnect);
|
||||||
case ~tag:0x04 (Variable.list point_encoding)
|
case ~tag:0x02 null
|
||||||
(function Advertise points -> Some points | _ -> None)
|
|
||||||
(fun points -> Advertise points);
|
|
||||||
case ~tag:0x05 null
|
|
||||||
(function Bootstrap -> Some () | _ -> None)
|
(function Bootstrap -> Some () | _ -> None)
|
||||||
(fun () -> Bootstrap);
|
(fun () -> Bootstrap);
|
||||||
|
case ~tag:0x03 (Variable.list point_encoding)
|
||||||
|
(function Advertise points -> Some points | _ -> None)
|
||||||
|
(fun points -> Advertise points);
|
||||||
] @
|
] @
|
||||||
ListLabels.map P.encodings ~f:begin function Encoding { tag; encoding; wrap; unwrap } ->
|
ListLabels.map P.encodings
|
||||||
|
~f:(function Encoding { tag ; encoding ; wrap ; unwrap } ->
|
||||||
case ~tag encoding
|
case ~tag encoding
|
||||||
(function Message msg -> unwrap msg | _ -> None)
|
(function Message msg -> unwrap msg | _ -> None)
|
||||||
(fun msg -> Message (wrap msg))
|
(fun msg -> Message (wrap msg))))
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
let max_length = function
|
let hdrlen = 2
|
||||||
| 0 -> Some 1024
|
let maxlen = hdrlen + 2 lsl 16
|
||||||
| 1 -> Some 0
|
|
||||||
| 2 -> Some 0
|
|
||||||
| 3 -> Some 0
|
|
||||||
| 4 -> Some (1 + 1000 * 17) (* tag + 1000 * max (point size) *)
|
|
||||||
| 5 -> Some 0
|
|
||||||
| n -> ListLabels.fold_left P.encodings ~init:None ~f:begin fun a -> function
|
|
||||||
Encoding { tag; max_length } -> if tag = n then max_length else a
|
|
||||||
end
|
|
||||||
|
|
||||||
module BE = EndianBigstring.BigEndian
|
|
||||||
|
|
||||||
(* read a message from a TCP socket *)
|
(* read a message from a TCP socket *)
|
||||||
let recv_msg fd buf =
|
let recv_msg fd buf =
|
||||||
catch (fun () ->
|
catch
|
||||||
Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen ->
|
(fun () ->
|
||||||
if hdrlen <> 4 then begin
|
assert (MBytes.length buf >= 2 lsl 16) ;
|
||||||
debug "read: could not read enough bytes to determine message size, aborting";
|
Lwt_utils.read_mbytes ~len:hdrlen fd buf >>= fun () ->
|
||||||
return None
|
let len = EndianBigstring.BigEndian.get_uint16 buf 0 in
|
||||||
end
|
(* TODO timeout read ??? *)
|
||||||
else
|
Lwt_utils.read_mbytes ~len fd buf >>= fun () ->
|
||||||
Lwt_bytes.read fd buf 0 4 >>= fun _hdrlen ->
|
(* TODO conditionnaly decrypt payload... ?? *)
|
||||||
let len = Int32.to_int (BE.get_int32 buf 0) in
|
match Data_encoding.Binary.read msg_encoding buf 0 len with
|
||||||
if len < 0 || len > MBytes.length buf then begin
|
| None ->
|
||||||
debug "read: invalid message size %d" len;
|
(* TODO track invalid message *)
|
||||||
return None
|
return Disconnect
|
||||||
end
|
| Some (read, _) when read <> len ->
|
||||||
else
|
(* TODO track invalid message *)
|
||||||
Lwt_utils.read_mbytes fd buf ~pos:4 ~len >|= fun () ->
|
return Disconnect
|
||||||
let tag = BE.get_uint8 buf 4 in
|
| Some (_, msg) ->
|
||||||
let msg = MBytes.sub buf 4 len in
|
Lwt.return msg)
|
||||||
match max_length tag with
|
|
||||||
| Some len when MBytes.length msg > len -> None
|
|
||||||
| _ -> Data_encoding.Binary.of_bytes msg_encoding msg)
|
|
||||||
(function
|
(function
|
||||||
| Unix.Unix_error (_err, _, _) -> return None
|
| Unix.Unix_error _ -> return Disconnect
|
||||||
| e -> fail e)
|
| e -> fail e)
|
||||||
|
|
||||||
(* send a message over a TCP socket *)
|
(* send a message over a TCP socket *)
|
||||||
let send_msg fd buf packet =
|
let send_msg fd buf msg =
|
||||||
catch
|
catch
|
||||||
(fun () ->
|
(fun () ->
|
||||||
match Data_encoding.Binary.write msg_encoding packet buf 4 with
|
match Data_encoding.Binary.write msg_encoding msg buf hdrlen with
|
||||||
| None -> return_false
|
| None -> return_false
|
||||||
| Some len ->
|
| Some len ->
|
||||||
BE.set_int32 buf 0 @@ Int32.of_int (len - 4);
|
if len > maxlen then
|
||||||
Lwt_utils.write_mbytes fd buf ~len >>= fun () ->
|
return_false
|
||||||
return_true
|
else begin
|
||||||
)
|
EndianBigstring.BigEndian.set_int16 buf 0 (len - hdrlen) ;
|
||||||
(fun exn -> Lwt.fail exn)
|
(* TODO conditionnaly encrypt payload... ? *)
|
||||||
|
(* TODO timeout write ??? *)
|
||||||
|
Lwt_utils.write_mbytes ~len fd buf >>= fun () ->
|
||||||
|
return true
|
||||||
|
end)
|
||||||
|
(function
|
||||||
|
| Unix.Unix_error _ -> return_false
|
||||||
|
| e -> fail e)
|
||||||
|
|
||||||
(* A peer handle, as a record-encoded object, abstract from the
|
(* A peer handle, as a record-encoded object, abstract from the
|
||||||
outside world. A hidden Lwt worker is associated to a peer at its
|
outside world. A hidden Lwt worker is associated to a peer at its
|
||||||
@ -377,17 +368,18 @@ module Make (P: PARAMS) = struct
|
|||||||
let cancelation, cancel, on_cancel = canceler () in
|
let cancelation, cancel, on_cancel = canceler () in
|
||||||
(* a cancelable reception *)
|
(* a cancelable reception *)
|
||||||
let recv buf =
|
let recv buf =
|
||||||
pick [ (recv_msg socket buf >|= function Some p -> p | None -> Disconnect);
|
pick [ recv_msg socket buf ;
|
||||||
(cancelation () >>= fun () -> return Disconnect) ] in
|
(cancelation () >>= fun () -> return Disconnect) ] in
|
||||||
(* First step: send and receive credentials, makes no difference
|
(* First step: send and receive credentials, makes no difference
|
||||||
whether we're trying to connect to a peer or checking an incoming
|
whether we're trying to connect to a peer or checking an incoming
|
||||||
connection, both parties must first present themselves. *)
|
connection, both parties must first present themselves. *)
|
||||||
let rec connect buf =
|
let rec connect buf =
|
||||||
send_msg socket buf (Connect { gid = my_gid ;
|
send_msg socket buf
|
||||||
|
(Connect { gid = my_gid ;
|
||||||
port = config.incoming_port ;
|
port = config.incoming_port ;
|
||||||
versions = P.supported_versions }) >>= fun _ ->
|
versions = P.supported_versions }) >>= fun _ ->
|
||||||
pick [ (LU.sleep limits.peer_answer_timeout >>= fun () -> return Disconnect) ;
|
pick [ (LU.sleep limits.peer_answer_timeout >>= fun () -> return Disconnect) ;
|
||||||
recv buf ] >>= function
|
recv_msg socket buf ] >>= function
|
||||||
| Connect { gid; port = listening_port; versions } ->
|
| Connect { gid; port = listening_port; versions } ->
|
||||||
debug "(%a) connection requested from %a @ %a:%d"
|
debug "(%a) connection requested from %a @ %a:%d"
|
||||||
pp_gid my_gid pp_gid gid Ipaddr.pp_hum addr port ;
|
pp_gid my_gid pp_gid gid Ipaddr.pp_hum addr port ;
|
||||||
@ -458,7 +450,7 @@ module Make (P: PARAMS) = struct
|
|||||||
(* Launch the worker *)
|
(* Launch the worker *)
|
||||||
receiver ()
|
receiver ()
|
||||||
in
|
in
|
||||||
let buf = MBytes.create 0x100_000 in
|
let buf = MBytes.create maxlen in
|
||||||
on_cancel (fun () ->
|
on_cancel (fun () ->
|
||||||
send_msg socket buf Disconnect >>= fun _ ->
|
send_msg socket buf Disconnect >>= fun _ ->
|
||||||
LU.close socket >>= fun _ ->
|
LU.close socket >>= fun _ ->
|
||||||
@ -542,8 +534,7 @@ module Make (P: PARAMS) = struct
|
|||||||
|
|
||||||
(* A good random string so it is probably unique on the network *)
|
(* A good random string so it is probably unique on the network *)
|
||||||
let fresh_gid () =
|
let fresh_gid () =
|
||||||
Bytes.to_string @@ Sodium.Random.Bytes.generate 16
|
Bytes.to_string @@ Sodium.Random.Bytes.generate gid_length
|
||||||
|
|
||||||
|
|
||||||
(* The (fixed size) broadcast frame. *)
|
(* The (fixed size) broadcast frame. *)
|
||||||
let discovery_message_encoding =
|
let discovery_message_encoding =
|
||||||
|
Loading…
Reference in New Issue
Block a user