Data_encoding: add check_size
This commit is contained in:
parent
5023e1a261
commit
2e9df07b0e
@ -16,6 +16,7 @@ type read_error =
|
||||
| Invalid_int of { min : int ; v : int ; max : int }
|
||||
| Invalid_float of { min : float ; v : float ; max : float }
|
||||
| Trailing_zero
|
||||
| Size_limit_exceeded
|
||||
|
||||
let pp_read_error ppf = function
|
||||
| Not_enough_data ->
|
||||
@ -34,6 +35,8 @@ let pp_read_error ppf = function
|
||||
Format.fprintf ppf "Invalid float (%f <= %f <= %f) " min v max
|
||||
| Trailing_zero ->
|
||||
Format.fprintf ppf "Trailing zero in Z"
|
||||
| Size_limit_exceeded ->
|
||||
Format.fprintf ppf "Size limit exceeded"
|
||||
|
||||
exception Read_error of read_error
|
||||
|
||||
|
@ -19,6 +19,7 @@ type read_error =
|
||||
| Invalid_int of { min : int ; v : int ; max : int }
|
||||
| Invalid_float of { min : float ; v : float ; max : float }
|
||||
| Trailing_zero
|
||||
| Size_limit_exceeded
|
||||
exception Read_error of read_error
|
||||
val pp_read_error: Format.formatter -> read_error -> unit
|
||||
|
||||
|
@ -110,6 +110,10 @@ let rec length : type x. x Encoding.t -> x -> int =
|
||||
| Splitted { encoding = e } -> length e value
|
||||
| Dynamic_size e ->
|
||||
Binary_size.int32 + length e value
|
||||
| Check_size { limit ; encoding = e } ->
|
||||
let length = length e value in
|
||||
if length > limit then raise (Write_error Size_limit_exceeded) ;
|
||||
length
|
||||
| Delayed f -> length (f ()) value
|
||||
|
||||
let fixed_length e =
|
||||
|
@ -15,19 +15,26 @@ type state = {
|
||||
buffer : MBytes.t ;
|
||||
mutable offset : int ;
|
||||
mutable remaining_bytes : int ;
|
||||
mutable allowed_bytes : int option ;
|
||||
}
|
||||
|
||||
let check_allowed_bytes state size =
|
||||
match state.allowed_bytes with
|
||||
| Some len when len < size -> raise Size_limit_exceeded
|
||||
| Some len -> Some (len - size)
|
||||
| None -> None
|
||||
|
||||
let check_remaining_bytes state size =
|
||||
if state.remaining_bytes < size then
|
||||
raise Not_enough_data ;
|
||||
state.remaining_bytes - size
|
||||
|
||||
let read_atom size conv state =
|
||||
let remaining_bytes = check_remaining_bytes state size in
|
||||
let res = conv state.buffer state.offset in
|
||||
let offset = state.offset in
|
||||
state.remaining_bytes <- check_remaining_bytes state size ;
|
||||
state.allowed_bytes <- check_allowed_bytes state size ;
|
||||
state.offset <- state.offset + size ;
|
||||
state.remaining_bytes <- remaining_bytes ;
|
||||
res
|
||||
conv state.buffer offset
|
||||
|
||||
(** Reader for all the atomic types. *)
|
||||
module Atom = struct
|
||||
@ -179,6 +186,7 @@ let rec read_rec : type ret. ret Encoding.t -> state -> ret
|
||||
Some (read_rec e state)
|
||||
| Objs (`Fixed sz, e1, e2) ->
|
||||
ignore (check_remaining_bytes state sz : int) ;
|
||||
ignore (check_allowed_bytes state sz : int option) ;
|
||||
let left = read_rec e1 state in
|
||||
let right = read_rec e2 state in
|
||||
(left, right)
|
||||
@ -191,6 +199,7 @@ let rec read_rec : type ret. ret Encoding.t -> state -> ret
|
||||
| Tup e -> read_rec e state
|
||||
| Tups (`Fixed sz, e1, e2) ->
|
||||
ignore (check_remaining_bytes state sz : int) ;
|
||||
ignore (check_allowed_bytes state sz : int option) ;
|
||||
let left = read_rec e1 state in
|
||||
let right = read_rec e2 state in
|
||||
(left, right)
|
||||
@ -219,10 +228,31 @@ let rec read_rec : type ret. ret Encoding.t -> state -> ret
|
||||
if sz < 0 then raise (Invalid_size sz) ;
|
||||
let remaining = check_remaining_bytes state sz in
|
||||
state.remaining_bytes <- sz ;
|
||||
ignore (check_allowed_bytes state sz : int option) ;
|
||||
let v = read_rec e state in
|
||||
if state.remaining_bytes <> 0 then raise Extra_bytes ;
|
||||
state.remaining_bytes <- remaining ;
|
||||
v
|
||||
| Check_size { limit ; encoding = e } ->
|
||||
let old_allowed_bytes = state.allowed_bytes in
|
||||
let limit =
|
||||
match state.allowed_bytes with
|
||||
| None -> limit
|
||||
| Some current_limit -> min current_limit limit in
|
||||
state.allowed_bytes <- Some limit ;
|
||||
let v = read_rec e state in
|
||||
let allowed_bytes =
|
||||
match old_allowed_bytes with
|
||||
| None -> None
|
||||
| Some old_limit ->
|
||||
let remaining =
|
||||
match state.allowed_bytes with
|
||||
| None -> assert false
|
||||
| Some remaining -> remaining in
|
||||
let read = limit - remaining in
|
||||
Some (old_limit - read) in
|
||||
state.allowed_bytes <- allowed_bytes ;
|
||||
v
|
||||
| Describe { encoding = e } -> read_rec e state
|
||||
| Def { encoding = e } -> read_rec e state
|
||||
| Splitted { encoding = e } -> read_rec e state
|
||||
@ -267,7 +297,8 @@ and read_list : type a. a Encoding.t -> state -> a list
|
||||
|
||||
let read encoding buffer ofs len =
|
||||
let state =
|
||||
{ buffer ; offset = ofs ; remaining_bytes = len } in
|
||||
{ buffer ; offset = ofs ;
|
||||
remaining_bytes = len ; allowed_bytes = None } in
|
||||
match read_rec encoding state with
|
||||
| exception Read_error _ -> None
|
||||
| v -> Some (state.offset, v)
|
||||
@ -275,7 +306,8 @@ let read encoding buffer ofs len =
|
||||
let of_bytes_exn encoding buffer =
|
||||
let len = MBytes.length buffer in
|
||||
let state =
|
||||
{ buffer ; offset = 0 ; remaining_bytes = len } in
|
||||
{ buffer ; offset = 0 ;
|
||||
remaining_bytes = len ; allowed_bytes = None } in
|
||||
let v = read_rec encoding state in
|
||||
if state.offset <> len then raise Extra_bytes ;
|
||||
v
|
||||
|
@ -22,6 +22,10 @@ type state = {
|
||||
unlimited). Reading fewer bytes should raise [Extra_bytes] and
|
||||
trying to read more bytes should raise [Not_enough_data]. *)
|
||||
|
||||
allowed_bytes : int option ;
|
||||
(** Maximum number of bytes that are allowed to be read from [stream]
|
||||
before failing (None = unlimited). *)
|
||||
|
||||
total_read : int ;
|
||||
(** Total number of bytes that have been read from [stream] since the
|
||||
beginning. *)
|
||||
@ -41,6 +45,12 @@ let check_remaining_bytes state size =
|
||||
| Some len -> Some (len - size)
|
||||
| None -> None
|
||||
|
||||
let check_allowed_bytes state size =
|
||||
match state.allowed_bytes with
|
||||
| Some len when len < size -> raise Size_limit_exceeded
|
||||
| Some len -> Some (len - size)
|
||||
| None -> None
|
||||
|
||||
(** [read_atom resume size conv state k] reads [size] bytes from [state],
|
||||
passes them to [conv] to be decoded, and finally calls the continuation [k]
|
||||
with the decoded value and the updated state.
|
||||
@ -61,9 +71,10 @@ let check_remaining_bytes state size =
|
||||
let read_atom resume size conv state k =
|
||||
match
|
||||
let remaining_bytes = check_remaining_bytes state size in
|
||||
let allowed_bytes = check_allowed_bytes state size in
|
||||
let res, stream = Binary_stream.read state.stream size in
|
||||
conv res.buffer res.ofs,
|
||||
{ remaining_bytes ; stream ;
|
||||
{ remaining_bytes ; allowed_bytes ; stream ;
|
||||
total_read = state.total_read + size }
|
||||
with
|
||||
| exception (Read_error error) -> Error error
|
||||
@ -242,6 +253,7 @@ let rec read_rec
|
||||
k (Some v, state)
|
||||
| Objs (`Fixed sz, e1, e2) ->
|
||||
ignore (check_remaining_bytes state sz : int option) ;
|
||||
ignore (check_allowed_bytes state sz : int option) ;
|
||||
read_rec e1 state @@ fun (left, state) ->
|
||||
read_rec e2 state @@ fun (right, state) ->
|
||||
k ((left, right), state)
|
||||
@ -254,6 +266,7 @@ let rec read_rec
|
||||
| Tup e -> read_rec e state k
|
||||
| Tups (`Fixed sz, e1, e2) ->
|
||||
ignore (check_remaining_bytes state sz : int option) ;
|
||||
ignore (check_allowed_bytes state sz : int option) ;
|
||||
read_rec e1 state @@ fun (left, state) ->
|
||||
read_rec e2 state @@ fun (right, state) ->
|
||||
k ((left, right), state)
|
||||
@ -288,11 +301,31 @@ let rec read_rec
|
||||
else
|
||||
let remaining = check_remaining_bytes state sz in
|
||||
let state = { state with remaining_bytes = Some sz } in
|
||||
ignore (check_allowed_bytes state sz : int option) ;
|
||||
read_rec e state @@ fun (v, state) ->
|
||||
if state.remaining_bytes <> Some 0 then
|
||||
Error Extra_bytes
|
||||
else
|
||||
k (v, { state with remaining_bytes = remaining })
|
||||
| Check_size { limit ; encoding = e } ->
|
||||
let old_allowed_bytes = state.allowed_bytes in
|
||||
let limit =
|
||||
match state.allowed_bytes with
|
||||
| None -> limit
|
||||
| Some current_limit -> min current_limit limit in
|
||||
let state = { state with allowed_bytes = Some limit } in
|
||||
read_rec e state @@ fun (v, state) ->
|
||||
let allowed_bytes =
|
||||
match old_allowed_bytes with
|
||||
| None -> None
|
||||
| Some old_limit ->
|
||||
let remaining =
|
||||
match state.allowed_bytes with
|
||||
| None -> assert false
|
||||
| Some remaining -> remaining in
|
||||
let read = limit - remaining in
|
||||
Some (old_limit - read) in
|
||||
k (v, { state with allowed_bytes })
|
||||
| Describe { encoding = e } -> read_rec e state k
|
||||
| Def { encoding = e } -> read_rec e state k
|
||||
| Splitted { encoding = e } -> read_rec e state k
|
||||
@ -362,6 +395,6 @@ let read_stream ?(init = Binary_stream.empty) encoding =
|
||||
invalid_arg "Data_encoding.Binary.read_stream: variable encoding"
|
||||
| `Dynamic | `Fixed _ ->
|
||||
(* No hardcoded read limit in a stream. *)
|
||||
let state = { remaining_bytes = None ;
|
||||
let state = { remaining_bytes = None ; allowed_bytes = None ;
|
||||
stream = init ; total_read = 0 } in
|
||||
read_rec encoding state success
|
||||
|
@ -268,6 +268,28 @@ let rec write_rec : type a. a Encoding.t -> state -> a -> unit =
|
||||
MBytes.set_int32
|
||||
state.buffer (initial_offset - Binary_size.int32)
|
||||
(Int32.of_int size)
|
||||
| Check_size { limit ; encoding = e } -> begin
|
||||
(* backup the current limit *)
|
||||
let old_limit = state.allowed_bytes in
|
||||
(* install the new limit (only if smaller than the current limit) *)
|
||||
let limit =
|
||||
match state.allowed_bytes with
|
||||
| None -> limit
|
||||
| Some old_limit -> min old_limit limit in
|
||||
state.allowed_bytes <- Some limit ;
|
||||
write_rec e state value ;
|
||||
(* restore the previous limit (minus the written bytes) *)
|
||||
match old_limit with
|
||||
| None ->
|
||||
state.allowed_bytes <- None
|
||||
| Some old_limit ->
|
||||
let remaining =
|
||||
match state.allowed_bytes with
|
||||
| None -> assert false
|
||||
| Some len -> len in
|
||||
let read = limit - remaining in
|
||||
state.allowed_bytes <- Some (old_limit - read)
|
||||
end
|
||||
| Describe { encoding = e } -> write_rec e state value
|
||||
| Def { encoding = e } -> write_rec e state value
|
||||
| Splitted { encoding = e } -> write_rec e state value
|
||||
|
@ -387,6 +387,12 @@ module Encoding: sig
|
||||
Usually used to fix errors from combining two encodings. *)
|
||||
val dynamic_size : 'a encoding -> 'a encoding
|
||||
|
||||
(** [check_size size encoding] ensures that the binary encoding
|
||||
of a value will not be allowed to exceed [size] bytes. The reader and
|
||||
the writer fail otherwise. This function does not modify
|
||||
the JSON encoding. *)
|
||||
val check_size : int -> 'a encoding -> 'a encoding
|
||||
|
||||
(** Recompute the encoding definition each time it is used.
|
||||
Useful for dynamically updating the encoding of values of an extensible
|
||||
type via a global reference (e.g. exceptions). *)
|
||||
@ -538,6 +544,7 @@ module Binary: sig
|
||||
| Invalid_int of { min : int ; v : int ; max : int }
|
||||
| Invalid_float of { min : float ; v : float ; max : float }
|
||||
| Trailing_zero
|
||||
| Size_limit_exceeded
|
||||
exception Read_error of read_error
|
||||
val pp_read_error: Format.formatter -> read_error -> unit
|
||||
|
||||
|
@ -107,6 +107,7 @@ type 'a desc =
|
||||
json_encoding : 'a Json_encoding.encoding ;
|
||||
is_obj : bool ; is_tup : bool } -> 'a desc
|
||||
| Dynamic_size : 'a t -> 'a desc
|
||||
| Check_size : { limit : int ; encoding : 'a t } -> 'a desc
|
||||
| Delayed : (unit -> 'a t) -> 'a desc
|
||||
|
||||
and _ field =
|
||||
@ -170,6 +171,7 @@ let rec classify : type a. a t -> Kind.t = fun e ->
|
||||
| Def { encoding } -> classify encoding
|
||||
| Splitted { encoding } -> classify encoding
|
||||
| Dynamic_size _ -> `Dynamic
|
||||
| Check_size { encoding } -> classify encoding
|
||||
| Delayed f -> classify (f ())
|
||||
|
||||
let make ?json_encoding encoding = { encoding ; json_encoding }
|
||||
@ -200,6 +202,9 @@ end
|
||||
let dynamic_size e =
|
||||
make @@ Dynamic_size e
|
||||
|
||||
let check_size limit encoding =
|
||||
make @@ Check_size { limit ; encoding }
|
||||
|
||||
let delayed f =
|
||||
make @@ Delayed f
|
||||
|
||||
@ -495,6 +500,7 @@ let rec is_nullable: type t. t encoding -> bool = fun e ->
|
||||
| Def { encoding = e } -> is_nullable e
|
||||
| Splitted { json_encoding } -> Json_encoding.is_nullable json_encoding
|
||||
| Dynamic_size e -> is_nullable e
|
||||
| Check_size { encoding = e } -> is_nullable e
|
||||
| Delayed _ -> true
|
||||
|
||||
let option ty =
|
||||
|
@ -65,6 +65,7 @@ type 'a desc =
|
||||
json_encoding : 'a Json_encoding.encoding ;
|
||||
is_obj : bool ; is_tup : bool } -> 'a desc
|
||||
| Dynamic_size : 'a t -> 'a desc
|
||||
| Check_size : { limit : int ; encoding : 'a t } -> 'a desc
|
||||
| Delayed : (unit -> 'a t) -> 'a desc
|
||||
|
||||
and _ field =
|
||||
@ -121,6 +122,7 @@ module Variable : sig
|
||||
val list : 'a encoding -> 'a list encoding
|
||||
end
|
||||
val dynamic_size : 'a encoding -> 'a encoding
|
||||
val check_size : int -> 'a encoding -> 'a encoding
|
||||
val delayed : (unit -> 'a encoding) -> 'a encoding
|
||||
val req :
|
||||
?title:string -> ?description:string ->
|
||||
|
@ -199,6 +199,7 @@ let rec json : type a. a Encoding.desc -> a Json_encoding.encoding =
|
||||
| Union (_tag_size, _, cases) -> union (List.map case_json cases)
|
||||
| Splitted { json_encoding } -> json_encoding
|
||||
| Dynamic_size e -> get_json e
|
||||
| Check_size { encoding } -> get_json encoding
|
||||
| Delayed f -> get_json (f ())
|
||||
|
||||
and field_json
|
||||
|
@ -81,6 +81,22 @@ let stream ?(expected = fun _ -> true) read_encoding bytes () =
|
||||
Binary.pp_read_error error
|
||||
done
|
||||
|
||||
let minimal_stream ?(expected = fun _ -> true) expected_read read_encoding bytes () =
|
||||
let name = "minimal_stream" in
|
||||
match streamed_read read_encoding bytes with
|
||||
| Binary.Success _, _ ->
|
||||
Alcotest.failf "%s failed: expecting exception, got success." name
|
||||
| Binary.Await _, _ ->
|
||||
Alcotest.failf "%s failed: not enough data" name
|
||||
| Binary.Error error, count when expected (Binary.Read_error error) && count = expected_read ->
|
||||
()
|
||||
| Binary.Error error, count ->
|
||||
Alcotest.failf
|
||||
"@[<v 2>%s failed: read error after reading %d. @ %a@]"
|
||||
name count
|
||||
Binary.pp_read_error error
|
||||
|
||||
|
||||
let all ?expected name write_encoding read_encoding value =
|
||||
let json_value = Json.construct write_encoding value in
|
||||
let bson_value = Bson.construct write_encoding value in
|
||||
@ -121,6 +137,36 @@ let all_ranged_float minimum maximum =
|
||||
all (name ^ ".min") float encoding (minimum -. 1.) @
|
||||
all (name ^ ".max") float encoding (maximum +. 1.)
|
||||
|
||||
let test_bounded_string_list =
|
||||
let expected = function
|
||||
| Binary_error.Read_error Size_limit_exceeded -> true
|
||||
| _ -> false in
|
||||
let test name ~total ~elements v expected_read expected_read' =
|
||||
let bytes = Binary.to_bytes_exn (Variable.list string) v in
|
||||
let vbytes = Binary.to_bytes_exn (list string) v in
|
||||
[ "bounded_string_list." ^ name, `Quick,
|
||||
binary ~expected (bounded_list ~total ~elements string) bytes ;
|
||||
"bounded_string_list_stream." ^ name, `Quick,
|
||||
stream ~expected
|
||||
(dynamic_size (bounded_list ~total:total ~elements string)) vbytes ;
|
||||
"bounded_string_list_minimal_stream." ^ name, `Quick,
|
||||
minimal_stream ~expected expected_read
|
||||
(dynamic_size (bounded_list ~total:total ~elements string)) vbytes ;
|
||||
"bounded_string_list_minimal_stream." ^ name, `Quick,
|
||||
minimal_stream ~expected expected_read'
|
||||
(check_size (total + 4)
|
||||
(dynamic_size (Variable.list (check_size elements string)))) vbytes ;
|
||||
|
||||
] in
|
||||
test "a" ~total:0 ~elements:0 [""] 4 4 @
|
||||
test "b1" ~total:3 ~elements:4 [""] 4 4 @
|
||||
test "b2" ~total:4 ~elements:3 [""] 4 4 @
|
||||
test "c1" ~total:19 ~elements:4 ["";"";"";"";""] 20 4 @
|
||||
test "c2" ~total:20 ~elements:3 ["";"";"";"";""] 4 4 @
|
||||
test "d1" ~total:20 ~elements:5 ["";"";"";"";"a"] 24 4 @
|
||||
test "d2" ~total:21 ~elements:4 ["";"";"";"";"a"] 24 24 @
|
||||
test "e" ~total:30 ~elements:10 ["ab";"c";"def";"gh";"ijk"] 32 4
|
||||
|
||||
let tests =
|
||||
all_ranged_int 100 400 @
|
||||
all_ranged_int 19000 19253 @
|
||||
@ -134,6 +180,7 @@ let tests =
|
||||
all "unknown_case.B" ~expected:missing_case union_enc mini_union_enc (B "2") @
|
||||
all "unknown_case.E" ~expected:missing_case union_enc mini_union_enc E @
|
||||
all "enum.missing" ~expected:missing_enum enum_enc mini_enum_enc 4 @
|
||||
test_bounded_string_list @
|
||||
[ "z.truncated", `Quick,
|
||||
binary ~expected:not_enough_data z (MBytes.of_string "\x83") ;
|
||||
"z.trailing_zero", `Quick,
|
||||
|
@ -122,6 +122,18 @@ let test_string_enum_boundary () =
|
||||
run_test entries2 ;
|
||||
run_test (("256", 256) :: entries2)
|
||||
|
||||
let test_bounded_string_list =
|
||||
let test name ~total ~elements v =
|
||||
"bounded_string_list." ^ name, `Quick,
|
||||
binary Alcotest.(list string)
|
||||
(bounded_list ~total ~elements string) v in
|
||||
[ test "a" ~total:0 ~elements:0 [] ;
|
||||
test "b" ~total:4 ~elements:4 [""] ;
|
||||
test "c" ~total:20 ~elements:4 ["";"";"";"";""] ;
|
||||
test "d" ~total:21 ~elements:5 ["";"";"";"";"a"] ;
|
||||
test "e" ~total:31 ~elements:10 ["ab";"c";"def";"gh";"ijk"] ;
|
||||
]
|
||||
|
||||
let tests =
|
||||
all "null" Alcotest.pass null () @
|
||||
all "empty" Alcotest.pass empty () @
|
||||
@ -219,5 +231,6 @@ let tests =
|
||||
all "array" Alcotest.(array int) (array int31) [|1;2;3;4;5|] @
|
||||
all "mu_list.empty" Alcotest.(list int) (mu_list_enc int31) [] @
|
||||
all "mu_list" Alcotest.(list int) (mu_list_enc int31) [1;2;3;4;5] @
|
||||
test_bounded_string_list @
|
||||
[ "string_enum_boundary", `Quick, test_string_enum_boundary ;
|
||||
]
|
||||
|
@ -162,6 +162,9 @@ let mu_list_enc enc =
|
||||
(fun (x, xs) -> x :: xs) ;
|
||||
]
|
||||
|
||||
let bounded_list ~total ~elements enc =
|
||||
check_size total (Variable.list (check_size elements enc))
|
||||
|
||||
module Alcotest = struct
|
||||
include Alcotest
|
||||
let float =
|
||||
|
@ -51,6 +51,23 @@ let all_ranged_float minimum maximum =
|
||||
all (name ^ ".min") encoding (minimum -. 1.) @
|
||||
all (name ^ ".max") encoding (maximum +. 1.)
|
||||
|
||||
let test_bounded_string_list =
|
||||
let expected = function
|
||||
| Binary_error.Write_error Size_limit_exceeded -> true
|
||||
| _ -> false in
|
||||
let test name ~total ~elements v =
|
||||
"bounded_string_list." ^ name, `Quick,
|
||||
binary ~expected (bounded_list ~total ~elements string) v in
|
||||
[ test "a" ~total:0 ~elements:0 [""] ;
|
||||
test "b1" ~total:3 ~elements:4 [""] ;
|
||||
test "b2" ~total:4 ~elements:3 [""] ;
|
||||
test "c1" ~total:19 ~elements:4 ["";"";"";"";""] ;
|
||||
test "c2" ~total:20 ~elements:3 ["";"";"";"";""] ;
|
||||
test "d1" ~total:20 ~elements:5 ["";"";"";"";"a"] ;
|
||||
test "d2" ~total:21 ~elements:4 ["";"";"";"";"a"] ;
|
||||
test "e" ~total:30 ~elements:10 ["ab";"c";"def";"gh";"ijk"] ;
|
||||
]
|
||||
|
||||
let tests =
|
||||
all_ranged_int 100 400 @
|
||||
all_ranged_int 19000 19254 @
|
||||
@ -61,4 +78,5 @@ let tests =
|
||||
all "bytes.fixed" (Fixed.bytes 4) (MBytes.of_string "turlututu") @
|
||||
all "unknown_case.B" mini_union_enc (B "2") @
|
||||
all "unknown_case.E" mini_union_enc E @
|
||||
test_bounded_string_list @
|
||||
[]
|
||||
|
Loading…
Reference in New Issue
Block a user