diff --git a/src/minutils/utils.ml b/src/minutils/utils.ml index d083f5cb4..bf6f36792 100644 --- a/src/minutils/utils.ml +++ b/src/minutils/utils.ml @@ -59,6 +59,10 @@ let unopt x = function | None -> x | Some x -> x +let unopt_map ~f ~default = function + | None -> default + | Some x -> f x + let unopt_list l = let may_cons xs x = match x with None -> xs | Some x -> x :: xs in List.rev @@ List.fold_left may_cons [] l @@ -72,6 +76,13 @@ let filter_map f l = let may_cons xs x = match f x with None -> xs | Some x -> x :: xs in List.rev @@ List.fold_left may_cons [] l +let list_rev_sub l n = + ListLabels.fold_left l ~init:(n, []) ~f:begin fun (n, l) elt -> + if n <= 0 then (n, l) else (n - 1, elt :: l) + end |> snd + +let list_sub l n = list_rev_sub l n |> List.rev + let display_paragraph ppf description = Format.fprintf ppf "@[%a@]" (fun ppf words -> List.iter (Format.fprintf ppf "%s@ ") words) @@ -111,3 +122,68 @@ let write_file ?(bin=false) fn contents = (fun () -> close_out oc) let (<<) g f = fun a -> g (f a) + +let rec (--) i j = + let rec loop acc j = + if j < i then acc else loop (j :: acc) (pred j) in + loop [] j + +let take_n_unsorted n l = + let rec loop acc n = function + | [] -> l + | _ when n <= 0 -> List.rev acc + | x :: xs -> loop (x :: acc) (pred n) xs in + loop [] n l + +module Bounded(E: Set.OrderedType) = struct + + (* TODO one day replace list by an heap array *) + + type t = { + bound : int ; + mutable size : int ; + mutable data : E.t list ; + } + let create bound = { bound ; size = 0 ; data = [] } + + let rec push x = function + | [] -> [x] + | (y :: xs) as ys -> + let c = compare x y in + if c < 0 then x :: ys else if c = 0 then ys else y :: push x xs + + let replace x xs = + match xs with + | y :: xs when compare x y > 0 -> + push x xs + | xs -> xs + + let insert x t = + if t.size < t.bound then begin + t.size <- t.size + 1 ; + t.data <- push x t.data + end else if E.compare (List.hd t.data) x < 0 then + t.data <- replace x t.data + + let get { data } = data + +end + +let take_n_sorted (type a) compare n l = + let module B = Bounded(struct type t = a let compare = compare end) in + let t = B.create n in + List.iter (fun x -> B.insert x t) l ; + B.get t + +let take_n ?compare n l = + match compare with + | None -> take_n_unsorted n l + | Some compare -> take_n_sorted compare n l + +let select n l = + let rec loop n acc = function + | [] -> invalid_arg "Utils.select" + | x :: xs when n <= 0 -> x, List.rev_append acc xs + | x :: xs -> loop (pred n) (x :: acc) xs + in + loop n [] l diff --git a/src/minutils/utils.mli b/src/minutils/utils.mli index 1c5a3f00a..0b3ec0f00 100644 --- a/src/minutils/utils.mli +++ b/src/minutils/utils.mli @@ -22,6 +22,7 @@ val map_option: f:('a -> 'b) -> 'a option -> 'b option val apply_option: f:('a -> 'b option) -> 'a option -> 'b option val iter_option: f:('a -> unit) -> 'a option -> unit val unopt: 'a -> 'a option -> 'a +val unopt_map: f:('a -> 'b) -> default:'b -> 'a option -> 'b val unopt_list: 'a option list -> 'a list val first_some: 'a option -> 'a option -> 'a option @@ -34,6 +35,11 @@ val remove_prefix: prefix:string -> string -> string option val filter_map: ('a -> 'b option) -> 'a list -> 'b list +(** [list_rev_sub l n] is (List.rev l) capped to max n elements *) +val list_rev_sub : 'a list -> int -> 'a list +(** [list_sub l n] is l capped to max n elements *) +val list_sub: 'a list -> int -> 'a list + val finalize: (unit -> 'a) -> (unit -> unit) -> 'a val read_file: ?bin:bool -> string -> string @@ -41,3 +47,20 @@ val write_file: ?bin:bool -> string -> string -> unit (** Compose functions from right to left. *) val (<<) : ('b -> 'c) -> ('a -> 'b) -> 'a -> 'c + +(** Sequence: [i--j] is the sequence [i;i+1;...;j-1;j] *) +val (--) : int -> int -> int list + +(** [take_n n l] returns the [n] first elements of [n]. When [compare] + is provided, it returns the [n] greatest element of [l]. *) +val take_n: ?compare:('a -> 'a -> int) -> int -> 'a list -> 'a list + +(** Bounded sequence: keep only the [n] greatest elements. *) +module Bounded(E: Set.OrderedType) : sig + type t + val create: int -> t + val insert: E.t -> t -> unit + val get: t -> E.t list +end + +val select: int -> 'a list -> 'a * 'a list